Moving Source from connect point to HostId in MulticastHandling

Change-Id: Ie8f678e150b7ee388680b8d8f27df0bce60ec01f
diff --git a/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java b/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
index 055d8de..9f15b6c 100644
--- a/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
+++ b/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
@@ -113,16 +113,40 @@
         final Optional<IpAddress> source = Optional.ofNullable(sourceIp);
         return store.getRoutes().stream()
                 .filter(route -> route.group().equals(groupIp) &&
-                                Objects.equal(route.source(), source))
+                        Objects.equal(route.source(), source))
                 .collect(Collectors.toSet());
     }
 
     @Override
-    public void addSources(McastRoute route, Set<ConnectPoint> connectPoints) {
+    public void addSource(McastRoute route, HostId source) {
         checkNotNull(route, "Route cannot be null");
-        checkNotNull(connectPoints, "Source cannot be null");
+        checkNotNull(source, "Source cannot be null");
         if (checkRoute(route)) {
-            store.storeSources(route, connectPoints);
+            Set<ConnectPoint> sources = new HashSet<>();
+            Host host = hostService.getHost(source);
+            if (host != null) {
+                sources.addAll(host.locations());
+            }
+            store.storeSource(route, source, sources);
+        }
+    }
+
+    @Override
+    public void addSources(McastRoute route, HostId hostId, Set<ConnectPoint> connectPoints) {
+        checkNotNull(route, "Route cannot be null");
+        checkNotNull(hostId, "HostId cannot be null");
+        checkNotNull(connectPoints, "Sources cannot be null");
+        if (checkRoute(route)) {
+            store.storeSource(route, hostId, connectPoints);
+        }
+    }
+
+    @Override
+    public void addSources(McastRoute route, Set<ConnectPoint> sources) {
+        checkNotNull(route, "Route cannot be null");
+        checkNotNull(sources, "sources cannot be null");
+        if (checkRoute(route)) {
+            store.storeSources(route, sources);
         }
     }
 
@@ -135,11 +159,11 @@
     }
 
     @Override
-    public void removeSources(McastRoute route, Set<ConnectPoint> sources) {
+    public void removeSource(McastRoute route, HostId source) {
         checkNotNull(route, "Route cannot be null");
-        checkNotNull(sources, "Source cannot be null");
+        checkNotNull(source, "Source cannot be null");
         if (checkRoute(route)) {
-            store.removeSources(route, sources);
+            store.removeSource(route, source);
         }
     }
 
@@ -149,8 +173,7 @@
             Set<ConnectPoint> sinks = new HashSet<>();
             Host host = hostService.getHost(hostId);
             if (host != null) {
-                host.locations().forEach(hostLocation -> sinks.add(
-                        ConnectPoint.deviceConnectPoint(hostLocation.deviceId() + "/" + hostLocation.port())));
+                sinks.addAll(host.locations());
             }
             store.addSink(route, hostId, sinks);
         }
@@ -166,7 +189,7 @@
     }
 
     @Override
-    public void addSink(McastRoute route, Set<ConnectPoint> sinks) {
+    public void addSinks(McastRoute route, Set<ConnectPoint> sinks) {
         checkNotNull(route, "Route cannot be null");
         checkNotNull(sinks, "Sinks cannot be null");
         if (checkRoute(route)) {
@@ -212,6 +235,12 @@
     }
 
     @Override
+    public Set<ConnectPoint> sources(McastRoute route, HostId hostId) {
+        checkNotNull(route, "Route cannot be null");
+        return checkRoute(route) ? store.sourcesFor(route, hostId) : ImmutableSet.of();
+    }
+
+    @Override
     public Set<ConnectPoint> sinks(McastRoute route) {
         checkNotNull(route, "Route cannot be null");
         return checkRoute(route) ? store.sinksFor(route) : ImmutableSet.of();
@@ -252,33 +281,55 @@
         public void event(HostEvent event) {
             HostId hostId = event.subject().id();
             log.debug("Host event: {}", event);
+            Set<McastRoute> routesForSource = routesForSource(hostId);
+            Set<McastRoute> routesForSink = routesForSink(hostId);
             switch (event.type()) {
                 case HOST_ADDED:
                     //the host is added, if it already comes with some locations let's use them
-                    eventAddSinks(hostId, event.subject().locations());
+                    if (!routesForSource.isEmpty()) {
+                        eventAddSources(hostId, event.subject().locations(), routesForSource);
+                    }
+                    if (!routesForSink.isEmpty()) {
+                        eventAddSinks(hostId, event.subject().locations(), routesForSink);
+                    }
                     break;
                 case HOST_MOVED:
                     //both subjects must be null or the system is in an incoherent state
                     if ((event.prevSubject() != null && event.subject() != null)) {
                         //we compute the difference between old locations and new ones and remove the previous
-                        Set<HostLocation> removedSinks = Sets.difference(event.prevSubject().locations(),
+                        Set<HostLocation> removedConnectPoint = Sets.difference(event.prevSubject().locations(),
                                 event.subject().locations()).immutableCopy();
-                        if (!removedSinks.isEmpty()) {
-                            eventRemoveSinks(hostId, removedSinks);
+                        if (!removedConnectPoint.isEmpty()) {
+                            if (!routesForSource.isEmpty()) {
+                                eventRemoveSources(hostId, removedConnectPoint, routesForSource);
+                            }
+                            if (!routesForSink.isEmpty()) {
+                                eventRemoveSinks(hostId, removedConnectPoint, routesForSink);
+                            }
                         }
-                        Set<HostLocation> addedSinks = Sets.difference(event.subject().locations(),
+                        Set<HostLocation> addedConnectPoints = Sets.difference(event.subject().locations(),
                                 event.prevSubject().locations()).immutableCopy();
                         //if the host now has some new locations we add them to the sinks set
-                        if (!addedSinks.isEmpty()) {
-                            eventAddSinks(hostId, addedSinks);
+                        if (!addedConnectPoints.isEmpty()) {
+                            if (!routesForSource.isEmpty()) {
+                                eventAddSources(hostId, addedConnectPoints, routesForSource);
+                            }
+                            if (!routesForSink.isEmpty()) {
+                                eventAddSinks(hostId, addedConnectPoints, routesForSink);
+                            }
                         }
                     }
                     break;
                 case HOST_REMOVED:
-                    // Removing all the sinks for that specific host
+                    // Removing all the connect points for that specific host
                     // even if the locations are 0 we keep
                     // the host information in the route in case it shows up again
-                    eventRemoveSinks(event.subject().id(), event.subject().locations());
+                    if (!routesForSource.isEmpty()) {
+                        eventRemoveSources(hostId, event.subject().locations(), routesForSource);
+                    }
+                    if (!routesForSink.isEmpty()) {
+                        eventRemoveSinks(hostId, event.subject().locations(), routesForSink);
+                    }
                     break;
                 case HOST_UPDATED:
                 default:
@@ -287,25 +338,52 @@
         }
     }
 
-    //Adds sinks for a given host event
-    private void eventRemoveSinks(HostId hostId, Set<HostLocation> removedSinks) {
+    //Finds the routes for which a host is a source
+    private Set<McastRoute> routesForSource(HostId hostId) {
+        // Filter by host id
+        return store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
+                .sources().containsKey(hostId)).collect(Collectors.toSet());
+    }
+
+    //Finds the routes for which a host is a sink
+    private Set<McastRoute> routesForSink(HostId hostId) {
+        return store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
+                .sinks().containsKey(hostId)).collect(Collectors.toSet());
+    }
+
+    //Removes sources for a given host event
+    private void eventRemoveSources(HostId hostId, Set<HostLocation> removedSources, Set<McastRoute> routesForSource) {
+        Set<ConnectPoint> sources = new HashSet<>();
+        // Build sources using host locations
+        sources.addAll(removedSources);
+        // Remove from each route the provided sources
+        routesForSource.forEach(route -> store.removeSources(route, hostId, sources));
+    }
+
+    //Adds the sources for a given host event
+    private void eventAddSources(HostId hostId, Set<HostLocation> addedSources, Set<McastRoute> routesForSource) {
+        Set<ConnectPoint> sources = new HashSet<>();
+        // Build source using host location
+        sources.addAll(addedSources);
+        // Add to each route the provided sources
+        routesForSource.forEach(route -> store.storeSource(route, hostId, sources));
+    }
+
+    //Removes sinks for a given host event
+    private void eventRemoveSinks(HostId hostId, Set<HostLocation> removedSinks, Set<McastRoute> routesForSinks) {
         Set<ConnectPoint> sinks = new HashSet<>();
         // Build sink using host location
         sinks.addAll(removedSinks);
-        // Filter by host id and then remove from each route the provided sinks
-        store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
-                .sinks().get(hostId) != null)
-                .forEach(route -> store.removeSinks(route, hostId, sinks));
+        // Remove from each route the provided sinks
+        routesForSinks.forEach(route -> store.removeSinks(route, hostId, sinks));
     }
 
-    //Removes the sinks for a given host event
-    private void eventAddSinks(HostId hostId, Set<HostLocation> addedSinks) {
+    //Adds the sinks for a given host event
+    private void eventAddSinks(HostId hostId, Set<HostLocation> addedSinks, Set<McastRoute> routesForSinks) {
         Set<ConnectPoint> sinks = new HashSet<>();
         // Build sink using host location
         sinks.addAll(addedSinks);
-        // Filter by host id and then add to each route the provided sinks
-        store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
-                .sinks().get(hostId) != null)
-                .forEach(route -> store.addSink(route, hostId, sinks));
+        // Add to each route the provided sinks
+        routesForSinks.forEach(route -> store.addSink(route, hostId, sinks));
     }
 }