Fix HOST event handling in MulticastRouteManager

Change-Id: I721470bd1879c1dc252346a0f4f085ca80f54156
(cherry picked from commit beea3e38fac2f6d763c62bb28bd7256b95bebd9c)
diff --git a/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java b/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
index 6ed8b4b..43b145c 100644
--- a/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
+++ b/apps/mcast/impl/src/main/java/org/onosproject/mcast/impl/MulticastRouteManager.java
@@ -15,7 +15,9 @@
  */
 package org.onosproject.mcast.impl;
 
+import com.google.common.base.Objects;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
 import org.apache.felix.scr.annotations.Activate;
 import org.apache.felix.scr.annotations.Component;
 import org.apache.felix.scr.annotations.Deactivate;
@@ -34,6 +36,7 @@
 import org.onosproject.net.ConnectPoint;
 import org.onosproject.net.Host;
 import org.onosproject.net.HostId;
+import org.onosproject.net.HostLocation;
 import org.onosproject.net.host.HostEvent;
 import org.onosproject.net.host.HostListener;
 import org.onosproject.net.host.HostService;
@@ -42,6 +45,7 @@
 import java.util.HashSet;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkNotNull;
 import static org.slf4j.LoggerFactory.getLogger;
@@ -104,11 +108,13 @@
     }
 
     @Override
-    public Optional<McastRoute> getRoute(IpAddress groupIp, IpAddress sourceIp) {
-        return store.getRoutes().stream().filter(route ->
-                route.group().equals(groupIp) &&
-                        route.source().isPresent() &&
-                        route.source().get().equals(sourceIp)).findAny();
+    public Set<McastRoute> getRoute(IpAddress groupIp, IpAddress sourceIp) {
+        // A null sourceIp becomes Optional.empty(), which matches only routes without a source
+        final Optional<IpAddress> wantedSource = Optional.ofNullable(sourceIp);
+        return store.getRoutes().stream()
+                .filter(mcastRoute -> mcastRoute.group().equals(groupIp))
+                .filter(mcastRoute -> Objects.equal(mcastRoute.source(), wantedSource))
+                .collect(Collectors.toSet());
     }
 
     @Override
@@ -178,19 +184,10 @@
     }
 
     @Override
-    public void removeSinks(McastRoute route, HostId hostId, Set<ConnectPoint> connectPoints) {
-        checkNotNull(route, "Route cannot be null");
-        if (checkRoute(route)) {
-            store.removeSinks(route, hostId, connectPoints);
-        }
-
-    }
-
-    @Override
     public void removeSinks(McastRoute route, Set<ConnectPoint> connectPoints) {
         checkNotNull(route, "Route cannot be null");
         if (checkRoute(route)) {
-            store.removeSinks(route, HostId.NONE, connectPoints);
+            store.removeSinks(route, connectPoints); // HostId-free store overload; replaces the HostId.NONE placeholder
         }
     }
 
@@ -227,7 +224,7 @@
     private class InternalMcastStoreDelegate implements McastStoreDelegate {
         @Override
         public void notify(McastEvent event) {
-            log.debug("Event: {}", event);
+            log.debug("Notify event: {}", event); // store delegate callback; the event is re-posted unchanged below
             post(event);
         }
     }
@@ -246,39 +243,61 @@
         @Override
         public void event(HostEvent event) {
             HostId hostId = event.subject().id();
-            Set<ConnectPoint> sinks = new HashSet<>();
-            log.debug("{} event", event);
-            //FIXME ther must be a better way
-            event.subject().locations().forEach(hostLocation -> sinks.add(
-                    ConnectPoint.deviceConnectPoint(hostLocation.deviceId() + "/" + hostLocation.port())));
+            log.debug("Host event: {}", event);
             switch (event.type()) {
                 case HOST_ADDED:
-                case HOST_UPDATED:
+                    // Host added: add any locations it already has to the routes that list this host
+                    eventAddSinks(hostId, event.subject().locations());
+                    break;
                 case HOST_MOVED:
-                    if ((event.prevSubject() == null && event.subject() != null)
-                            || (event.prevSubject().locations().size() > event.subject().locations().size())) {
-                        store.getRoutes().stream().filter(mcastRoute -> {
-                            return store.getRouteData(mcastRoute).sinks().get(hostId) != null;
-                        }).forEach(route -> {
-                            store.removeSinks(route, hostId, sinks);
-                        });
-                    } else if (event.prevSubject().locations().size() < event.subject().locations().size()) {
-                        store.getRoutes().stream().filter(mcastRoute -> {
-                            return store.getRouteData(mcastRoute).sinks().get(hostId) != null;
-                        }).forEach(route -> {
-                            store.addSink(route, hostId, sinks);
-                        });
+                    // Both subjects must be non-null, otherwise the system is in an incoherent state
+                    if (event.prevSubject() != null && event.subject() != null) {
+                        // Locations the host no longer has: remove them as sinks
+                        Set<HostLocation> removedSinks = Sets.difference(event.prevSubject().locations(),
+                                event.subject().locations()).immutableCopy();
+                        if (!removedSinks.isEmpty()) {
+                            eventRemoveSinks(hostId, removedSinks);
+                        }
+                        Set<HostLocation> addedSinks = Sets.difference(event.subject().locations(),
+                                event.prevSubject().locations()).immutableCopy();
+                        // Locations the host has gained: add them as sinks
+                        if (!addedSinks.isEmpty()) {
+                            eventAddSinks(hostId, addedSinks);
+                        }
                     }
                     break;
                 case HOST_REMOVED:
-                    store.getRoutes().stream().filter(mcastRoute -> {
-                        return store.getRouteData(mcastRoute).sinks().get(hostId) != null;
-                    }).forEach(route -> {
-                        store.removeSink(route, hostId);
-                    });
+                    // Remove the host's last known locations from every route that lists it.
+                    // The host entry itself is intentionally kept in the route (even with no locations)
+                    // so that its sinks can be restored if the host shows up again
+                    eventRemoveSinks(hostId, event.subject().locations());
+                    break;
+                case HOST_UPDATED:
                 default:
-                    log.debug("Host event {} not supported", event.type());
+                    log.debug("Host event {} not handled", event.type());
             }
         }
     }
+
+    // Removes the given host locations, as sinks, from every route that already lists hostId
+    private void eventRemoveSinks(HostId hostId, Set<HostLocation> removedSinks) {
+        Set<ConnectPoint> sinks = new HashSet<>();
+        // Store plain ConnectPoints, not HostLocation subclasses (whose equality may include a timestamp)
+        removedSinks.forEach(loc -> sinks.add(new ConnectPoint(loc.elementId(), loc.port())));
+        // Filter by host id and then remove from each route the provided sinks
+        store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
+                .sinks().get(hostId) != null)
+                .forEach(route -> store.removeSinks(route, hostId, sinks));
+    }
+
+    // Adds the given host locations, as sinks, to every route that already lists hostId
+    private void eventAddSinks(HostId hostId, Set<HostLocation> addedSinks) {
+        Set<ConnectPoint> sinks = new HashSet<>();
+        // Store plain ConnectPoints, not HostLocation subclasses (whose equality may include a timestamp)
+        addedSinks.forEach(loc -> sinks.add(new ConnectPoint(loc.elementId(), loc.port())));
+        // Filter by host id and then add to each route the provided sinks
+        store.getRoutes().stream().filter(mcastRoute -> store.getRouteData(mcastRoute)
+                .sinks().get(hostId) != null)
+                .forEach(route -> store.addSink(route, hostId, sinks));
+    }
 }