Refactors McastHandler to optimize the handling of network failures

Changes:
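- Introduces a paths store
- Optimizes the handling of SINKS_ADDED and SINKS_REMOVED events
- Uses SINKS_REMOVED to handle egress failures
- Optimizes the handling of SOURCES_ADDED and SOURCES_REMOVED events
- Leverages SOURCES_REMOVED to handle ingress failures
- Optimizes the handling of infrastructure failures
- Computes mcast trees from topology paths (shared links enforced through
  SRLinkWeigher) instead of rebuilding them from the next objectives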

Change-Id: I16d264f58d6fe11cfce4a546f7b4ab82a9fcc21b
diff --git a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
index c4e1ad4..396b998 100644
--- a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
+++ b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
@@ -16,6 +16,7 @@
 
 package org.onosproject.segmentrouting.mcast;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
@@ -34,6 +35,7 @@
 import org.onosproject.net.DeviceId;
 import org.onosproject.net.HostId;
 import org.onosproject.net.Link;
+import org.onosproject.net.Path;
 import org.onosproject.net.PortNumber;
 import org.onosproject.net.config.basics.McastConfig;
 import org.onosproject.net.flow.DefaultTrafficSelector;
@@ -50,6 +52,10 @@
 import org.onosproject.net.flowobjective.ForwardingObjective;
 import org.onosproject.net.flowobjective.NextObjective;
 import org.onosproject.net.flowobjective.ObjectiveContext;
+import org.onosproject.net.topology.LinkWeigher;
+import org.onosproject.net.topology.Topology;
+import org.onosproject.net.topology.TopologyService;
+import org.onosproject.segmentrouting.SRLinkWeigher;
 import org.onosproject.segmentrouting.SegmentRoutingManager;
 import org.onosproject.segmentrouting.SegmentRoutingService;
 import org.onosproject.segmentrouting.config.DeviceConfigNotFoundException;
@@ -57,8 +63,10 @@
 import org.slf4j.Logger;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -66,11 +74,11 @@
  * Utility class for Multicast Handler.
  */
 class McastUtils {
-
     // Internal reference to the log
     private final Logger log;
-    // Internal reference to SR Manager
-    private SegmentRoutingManager srManager;
+    // Internal reference to SR Manager and topology service
+    private final SegmentRoutingManager srManager;
+    private final TopologyService topologyService;
     // Internal reference to the app id
     private ApplicationId coreAppId;
     // Hashing function for the multicast hasher
@@ -87,6 +95,7 @@
      */
     McastUtils(SegmentRoutingManager srManager, ApplicationId coreAppId, Logger log) {
         this.srManager = srManager;
+        this.topologyService = srManager.topologyService;
         this.coreAppId = coreAppId;
         this.log = log;
         this.mcastLeaderCache = Maps.newConcurrentMap();
@@ -95,7 +104,7 @@
     /**
      * Clean up when deactivating the application.
      */
-    public void terminate() {
+    void terminate() {
         mcastLeaderCache.clear();
     }
 
@@ -233,6 +242,8 @@
         return srManager.getDefaultInternalVlan();
     }
 
+
+
     /**
      * Gets sources connect points of given multicast group.
      *
@@ -477,64 +488,247 @@
     }
 
     /**
-     * Build recursively the mcast paths.
+     * Goes through all the paths, looking for the links shared by the egress
+     * devices: these links are then enforced in the final path computation.
      *
-     * @param mcastNextObjStore mcast next obj store
-     * @param toVisit the node to visit
-     * @param visited the visited nodes
-     * @param mcastPaths the current mcast paths
-     * @param currentPath the current path
-     * @param mcastIp the group ip
-     * @param source the source
+     * At every hop index (up to the length of the shortest available path)
+     * the links used by the paths of all the egresses are intersected. If not
+     * even the first hop can be shared, one of the egresses breaking the
+     * intersection is dropped and the exploration is retried with the
+     * remaining ones, so that optimal subtrees can still be built.
+     *
+     * @param egresses egress devices; egresses may be removed from this set,
+     *                 so callers should pass a copy
+     * @param availablePaths all the available paths towards the egress
+     * @return shared links between egress devices, empty if there are none
      */
-    void buildMcastPaths(Map<McastStoreKey, NextObjective> mcastNextObjStore,
-                                 DeviceId toVisit, Set<DeviceId> visited,
-                                 Map<ConnectPoint, List<ConnectPoint>> mcastPaths,
-                                 List<ConnectPoint> currentPath, IpAddress mcastIp,
-                                 ConnectPoint source) {
-        log.debug("Building Multicast paths recursively for {} - next device to visit is {}",
-                  mcastIp, toVisit);
-        // If we have visited the node to visit there is a loop
-        if (visited.contains(toVisit)) {
-            return;
-        }
-        // Visit next-hop
-        visited.add(toVisit);
-        VlanId assignedVlan = assignedVlan(toVisit.equals(source.deviceId()) ? source : null);
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, toVisit, assignedVlan);
-        // Looking for next-hops
-        if (mcastNextObjStore.containsKey(mcastStoreKey)) {
-            // Build egress connect points, get ports and build relative cps
-            NextObjective nextObjective = mcastNextObjStore.get(mcastStoreKey);
-            Set<PortNumber> outputPorts = getPorts(nextObjective.next());
-            ImmutableSet.Builder<ConnectPoint> cpBuilder = ImmutableSet.builder();
-            outputPorts.forEach(portNumber -> cpBuilder.add(new ConnectPoint(toVisit, portNumber)));
-            Set<ConnectPoint> egressPoints = cpBuilder.build();
-            Set<Link> egressLinks;
-            List<ConnectPoint> newCurrentPath;
-            Set<DeviceId> newVisited;
-            DeviceId newToVisit;
-            for (ConnectPoint egressPoint : egressPoints) {
-                egressLinks = srManager.linkService.getEgressLinks(egressPoint);
-                // If it does not have egress links, stop
-                if (egressLinks.isEmpty()) {
-                    // Add the connect points to the path
-                    newCurrentPath = Lists.newArrayList(currentPath);
-                    newCurrentPath.add(0, egressPoint);
-                    mcastPaths.put(egressPoint, newCurrentPath);
-                } else {
-                    newVisited = Sets.newHashSet(visited);
-                    // Iterate over the egress links for the next hops
-                    for (Link egressLink : egressLinks) {
-                        newToVisit = egressLink.dst().deviceId();
-                        newCurrentPath = Lists.newArrayList(currentPath);
-                        newCurrentPath.add(0, egressPoint);
-                        newCurrentPath.add(0, egressLink.dst());
-                        buildMcastPaths(mcastNextObjStore, newToVisit, newVisited, mcastPaths, newCurrentPath, mcastIp,
-                                source);
-                    }
-                }
+    private Set<Link> exploreMcastTree(Set<DeviceId> egresses,
+                                       Map<DeviceId, List<Path>> availablePaths) {
+        // Shortest hop count amongst all the available paths. Every path is taken
+        // into account, not only the first one of each egress: equal-cost paths
+        // may have a different number of links and the index-based scan below
+        // must never go past the end of any of them
+        int minLength = Integer.MAX_VALUE;
+        for (DeviceId egress : egresses) {
+            // From the source we may not reach all the sinks: unreachable egresses
+            // contribute no path here and are dealt with during the exploration
+            for (Path path : availablePaths.getOrDefault(egress, Collections.emptyList())) {
+                minLength = Math.min(minLength, path.links().size());
             }
         }
+        // If there are no paths there is nothing to share
+        if (minLength == Integer.MAX_VALUE) {
+            return Collections.emptySet();
+        }
+        Set<Link> sharedLinks = Sets.newHashSet();
+        DeviceId egressToRemove = null;
+        // Seed the intersection with the first egress (not empty: some path exists)
+        DeviceId firstEgress = egresses.iterator().next();
+        // Let's find out the shared links, hop by hop, until the paths diverge
+        int index = 0;
+        while (index < minLength && egressToRemove == null) {
+            Set<Link> currentSharedLinks = linksAt(availablePaths.get(firstEgress), index);
+            // Intersect with the links used by every egress at the same hop
+            for (DeviceId egress : egresses) {
+                currentSharedLinks.retainAll(linksAt(availablePaths.get(egress), index));
+                // If there are no shared links record the device to remove and stop:
+                // we may have to retry with a subset of sinks
+                if (currentSharedLinks.isEmpty()) {
+                    egressToRemove = egress;
+                    break;
+                }
+            }
+            sharedLinks.addAll(currentSharedLinks);
+            index++;
+        }
+        // If the shared links are empty and there are several egresses let's retry
+        // with fewer sinks, we can still build optimal subtrees
+        if (sharedLinks.isEmpty() && egresses.size() > 1 && egressToRemove != null) {
+            egresses.remove(egressToRemove);
+            sharedLinks = exploreMcastTree(egresses, availablePaths);
+        }
+        return sharedLinks;
     }
+
+    /**
+     * Collects the links found at the given hop of the given paths.
+     *
+     * @param paths the paths to scan; null is treated as no paths
+     * @param index the hop index, lower than the number of links of every path
+     * @return a new mutable set with the links at position index
+     */
+    private Set<Link> linksAt(List<Path> paths, int index) {
+        Set<Link> links = Sets.newHashSet();
+        if (paths != null) {
+            paths.forEach(path -> links.add(path.links().get(index)));
+        }
+        return links;
+    }
+
+    /**
+     * Builds the Mcast tree having as root the given source and as leaves the given egress points.
+     *
+     * @param mcastIp multicast group
+     * @param source source of the tree
+     * @param sinks leaves of the tree
+     * @return the computed Mcast tree; every sink is a key, sinks attached to the
+     *         source device or without any path map to an empty list
+     */
+    Map<ConnectPoint, List<Path>> computeSinkMcastTree(IpAddress mcastIp,
+                                                       DeviceId source,
+                                                       Set<ConnectPoint> sinks) {
+        // Get the egress devices, remove source from the egress if present
+        Set<DeviceId> egresses = sinks.stream().map(ConnectPoint::deviceId)
+                .filter(deviceId -> !deviceId.equals(source)).collect(Collectors.toSet());
+        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(mcastIp, source, egresses);
+        final Map<ConnectPoint, List<Path>> finalTree = Maps.newHashMap();
+        // We need to put back the sinks attached to the source device, if any:
+        // they were not part of the computation and get no path
+        sinks.forEach(sink -> {
+            List<Path> sinkPaths = mcastTree.get(sink.deviceId());
+            finalTree.put(sink, sinkPaths != null ? sinkPaths : ImmutableList.of());
+        });
+        return finalTree;
+    }
+
+    /**
+     * Builds the Mcast tree having as root the given source and as leaves the given egress.
+     * Paths are computed twice: first to discover the links shared by the egresses,
+     * then again enforcing those links through the SR link weigher.
+     *
+     * @param mcastIp multicast group
+     * @param source source of the tree
+     * @param egresses leaves of the tree
+     * @return the computed Mcast tree, mapping every egress to its paths
+     */
+    private Map<DeviceId, List<Path>> computeMcastTree(IpAddress mcastIp,
+                                                       DeviceId source,
+                                                       Set<DeviceId> egresses) {
+        log.debug("Computing tree for Multicast group {}, source {} and leafs {}",
+                  mcastIp, source, egresses);
+        // Pre-compute all the paths
+        Map<DeviceId, List<Path>> availablePaths = Maps.newHashMap();
+        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
+                                                                       Collections.emptySet())));
+        // Explore the topology looking for shared links amongst the egresses; the
+        // exploration may drop egresses from its input, hence the copy
+        Set<Link> linksToEnforce = exploreMcastTree(Sets.newHashSet(egresses), availablePaths);
+        // Build the final paths enforcing the shared links between egress devices
+        availablePaths.clear();
+        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
+                                                                       linksToEnforce)));
+        return availablePaths;
+    }
+
+    /**
+     * Gets the paths from src to dst computed on the current topology using the
+     * custom SR link weigher.
+     *
+     * @param src source device ID
+     * @param dst destination device ID
+     * @param linksToEnforce links to be enforced; may be empty
+     * @return list of paths from src to dst, possibly empty
+     */
+    List<Path> getPaths(DeviceId src, DeviceId dst, Set<Link> linksToEnforce) {
+        final Topology currentTopology = topologyService.currentTopology();
+        final LinkWeigher linkWeigher = new SRLinkWeigher(srManager, src, linksToEnforce);
+        List<Path> allPaths = Lists.newArrayList(topologyService.getPaths(currentTopology, src, dst, linkWeigher));
+        log.trace("{} path(s) found from {} to {}", allPaths.size(), src, dst);
+        return allPaths;
+    }
+
+    /**
+     * Gets a stored path having dst as egress, i.e. whose last link ends on dst.
+     * Empty paths have no egress and are skipped.
+     *
+     * @param dst destination device ID
+     * @param storedPaths paths list
+     * @return an optional path, empty if no stored path ends on dst
+     */
+    Optional<? extends List<Link>> getStoredPath(DeviceId dst, Collection<? extends List<Link>> storedPaths) {
+        return storedPaths.stream()
+                .filter(path -> !path.isEmpty())
+                .filter(path -> path.get(path.size() - 1).dst().deviceId().equals(dst))
+                .findFirst();
+    }
+
+    /**
+     * Returns the paths affected by the failed element.
+     *
+     * @param paths the paths to check
+     * @param failedElement the failed element; anything that is not a DeviceId
+     *                      is handled as a Link
+     * @return the affected paths
+     */
+    Set<List<Link>> getAffectedPaths(Set<List<Link>> paths, Object failedElement) {
+        if (failedElement instanceof DeviceId) {
+            return getAffectedPathsByDevice(paths, failedElement);
+        }
+        return getAffectedPathsByLink(paths, failedElement);
+    }
+
+    // A path is affected by a device failure when one of its links leaves that device.
+    // NOTE(review): the last device of a path is never matched; presumably intended,
+    // as egress failures are handled through SINKS_REMOVED - confirm
+    private Set<List<Link>> getAffectedPathsByDevice(Set<List<Link>> paths, Object failedElement) {
+        DeviceId affectedDevice = (DeviceId) failedElement;
+        Set<List<Link>> affectedPaths = Sets.newHashSet();
+        paths.forEach(path -> {
+            if (path.stream().anyMatch(link -> link.src().deviceId().equals(affectedDevice))) {
+                affectedPaths.add(path);
+            }
+        });
+        return affectedPaths;
+    }
+
+    // A path is affected by a link failure when it contains that very link
+    private Set<List<Link>> getAffectedPathsByLink(Set<List<Link>> paths, Object failedElement) {
+        Link affectedLink = (Link) failedElement;
+        Set<List<Link>> affectedPaths = Sets.newHashSet();
+        paths.forEach(path -> {
+            if (path.contains(affectedLink)) {
+                affectedPaths.add(path);
+            }
+        });
+        return affectedPaths;
+    }
+
+    /**
+     * Checks if the failure is affecting the transit devices.
+     *
+     * @param devices the transit devices
+     * @param failedElement the failed element
+     * @return true if the failed element is one of the transit devices; always
+     *         true when the failed element is not a DeviceId (e.g. a link)
+     */
+    boolean isInfraFailure(Set<DeviceId> devices, Object failedElement) {
+        if (failedElement instanceof DeviceId) {
+            return isInfraFailureByDevice(devices, failedElement);
+        }
+        return true;
+    }
+
+    private boolean isInfraFailureByDevice(Set<DeviceId> devices, Object failedElement) {
+        DeviceId affectedDevice = (DeviceId) failedElement;
+        return devices.contains(affectedDevice);
+    }
+
+    /**
+     * Checks if a port is an infra port, i.e. the source or destination of a
+     * link of the given paths.
+     *
+     * @param connectPoint port to be checked
+     * @param storedPaths paths to be checked against
+     * @return true if the port is an infra port. False otherwise.
+     */
+    boolean isInfraPort(ConnectPoint connectPoint, Collection<? extends List<Link>> storedPaths) {
+        for (List<Link> path : storedPaths) {
+            if (path.stream().anyMatch(link -> link.src().equals(connectPoint) ||
+                    link.dst().equals(connectPoint))) {
+                return true;
+            }
+        }
+        return false;
+    }
+
 }