Refactors McastHandler to optimize network failures

Changes:
- Introduces a paths store
- Optimizes SINKS_ADDED, SINKS_REMOVED
- Uses SINKS_REMOVED to handle egress failures
- Optimizes SOURCES_ADDED, SOURCES_REMOVED
- Leverages SOURCES_REMOVED to handle ingress failures
- Optimizes infra failures

Change-Id: I16d264f58d6fe11cfce4a546f7b4ab82a9fcc21b
diff --git a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
index 5533413..e46c1d8 100644
--- a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
+++ b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
@@ -18,7 +18,6 @@
 
 import com.google.common.base.Objects;
 import com.google.common.collect.HashMultimap;
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -46,13 +45,10 @@
 import org.onosproject.net.flowobjective.ForwardingObjective;
 import org.onosproject.net.flowobjective.NextObjective;
 import org.onosproject.net.flowobjective.ObjectiveContext;
-import org.onosproject.net.topology.LinkWeigher;
-import org.onosproject.net.topology.Topology;
-import org.onosproject.net.topology.TopologyService;
-import org.onosproject.segmentrouting.SRLinkWeigher;
 import org.onosproject.segmentrouting.SegmentRoutingManager;
 import org.onosproject.store.serializers.KryoNamespaces;
 import org.onosproject.store.service.ConsistentMap;
+import org.onosproject.store.service.ConsistentMultimap;
 import org.onosproject.store.service.DistributedSet;
 import org.onosproject.store.service.Serializer;
 import org.onosproject.store.service.Versioned;
@@ -96,10 +92,10 @@
     // Internal elements
     private static final Logger log = LoggerFactory.getLogger(McastHandler.class);
     private final SegmentRoutingManager srManager;
-    private final TopologyService topologyService;
     private final McastUtils mcastUtils;
     private final ConsistentMap<McastStoreKey, NextObjective> mcastNextObjStore;
     private final ConsistentMap<McastRoleStoreKey, McastRole> mcastRoleStore;
+    private final ConsistentMultimap<McastPathStoreKey, List<Link>> mcastPathStore;
     private final DistributedSet<McastFilteringObjStoreKey> mcastFilteringObjStore;
     // Stability threshold for Mcast. Seconds
     private static final long MCAST_STABLITY_THRESHOLD = 5;
@@ -125,7 +121,6 @@
     public McastHandler(SegmentRoutingManager srManager) {
         ApplicationId coreAppId = srManager.coreService.getAppId(CoreService.CORE_APP_NAME);
         this.srManager = srManager;
-        this.topologyService = srManager.topologyService;
         KryoNamespace.Builder mcastKryo = new KryoNamespace.Builder()
                 .register(KryoNamespaces.API)
                 .register(new McastStoreKeySerializer(), McastStoreKey.class);
@@ -152,6 +147,14 @@
                 .withSerializer(Serializer.using(mcastKryo.build("McastHandler-FilteringObj")))
                 .build()
                 .asDistributedSet();
+        mcastKryo = new KryoNamespace.Builder()
+                .register(KryoNamespaces.API)
+                .register(new McastPathStoreKeySerializer(), McastPathStoreKey.class);
+        mcastPathStore = srManager.storageService
+                .<McastPathStoreKey, List<Link>>consistentMultimapBuilder()
+                .withName("onos-mcast-path-store")
+                .withSerializer(Serializer.using(mcastKryo.build("McastHandler-Path")))
+                .build();
         mcastUtils = new McastUtils(srManager, coreAppId, log);
         // Init the executor for the buckets corrector
         mcastCorrector.scheduleWithFixedDelay(new McastBucketCorrector(), 10,
@@ -193,8 +196,8 @@
     }
 
     private void initInternal() {
-        lastMcastChange.set(Instant.now());
         srManager.multicastRouteService.getRoutes().forEach(mcastRoute -> {
+            lastMcastChange.set(Instant.now());
             log.debug("Init group {}", mcastRoute.group());
             if (!mcastUtils.isLeader(mcastRoute.group())) {
                 log.debug("Skip {} due to lack of leadership", mcastRoute.group());
@@ -203,11 +206,11 @@
             McastRouteData mcastRouteData = srManager.multicastRouteService.routeData(mcastRoute);
             // For each source process the mcast tree
             srManager.multicastRouteService.sources(mcastRoute).forEach(source -> {
-                Map<ConnectPoint, List<ConnectPoint>> mcastPaths = Maps.newHashMap();
-                Set<DeviceId> visited = Sets.newHashSet();
-                List<ConnectPoint> currentPath = Lists.newArrayList(source);
-                mcastUtils.buildMcastPaths(mcastNextObjStore.asJavaMap(), source.deviceId(), visited, mcastPaths,
-                                           currentPath, mcastRoute.group(), source);
+                McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastRoute.group(), source);
+                Collection<? extends List<Link>> storedPaths = Versioned.valueOrElse(
+                        mcastPathStore.get(pathStoreKey), Lists.newArrayList());
+                Map<ConnectPoint, List<ConnectPoint>> mcastPaths = buildMcastPaths(storedPaths, mcastRoute.group(),
+                                                                                   source);
                 // Get all the sinks and process them
                 Set<ConnectPoint> sinks = processSinksToBeAdded(source, mcastRoute.group(),
                                                                 mcastRouteData.sinks());
@@ -221,10 +224,10 @@
                     log.debug("Skip {} for source {} nothing to do", mcastRoute.group(), source);
                     return;
                 }
-                Map<ConnectPoint, List<Path>> mcasTree = computeSinkMcastTree(mcastRoute.group(),
-                                                                              source.deviceId(), sinks);
-                mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink,
-                                                                           mcastRoute.group(), paths));
+                Map<ConnectPoint, List<Path>> mcasTree = mcastUtils.computeSinkMcastTree(mcastRoute.group(),
+                                                                                         source.deviceId(), sinks);
+                mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink, mcastRoute.group(),
+                                                                             null));
             });
         });
     }
@@ -238,6 +241,7 @@
         mcastNextObjStore.destroy();
         mcastRoleStore.destroy();
         mcastFilteringObjStore.destroy();
+        mcastPathStore.destroy();
         mcastUtils.terminate();
         log.info("Terminated");
     }
@@ -310,8 +314,8 @@
         }
         sources.forEach(source -> {
             Set<ConnectPoint> sinksToBeAdded = processSinksToBeAdded(source, mcastIp, sinks);
-            Map<ConnectPoint, List<Path>> mcasTree = computeSinkMcastTree(mcastIp, source.deviceId(),
-                                                                          sinksToBeAdded);
+            Map<ConnectPoint, List<Path>> mcasTree = mcastUtils.computeSinkMcastTree(mcastIp, source.deviceId(),
+                                                                                     sinksToBeAdded);
             mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink, mcastIp, paths));
         });
     }
@@ -339,42 +343,41 @@
             processRouteRemovedInternal(sourcesToBeRemoved, mcastIp);
             return;
         }
-        // Skip offline devices
-        Set<ConnectPoint> candidateSources = sourcesToBeRemoved.stream()
-                .filter(source -> srManager.deviceService.isAvailable(source.deviceId()))
-                .collect(Collectors.toSet());
-        if (candidateSources.isEmpty()) {
-            log.debug("Skip {} due to empty sources to be removed", mcastIp);
-            return;
-        }
         // Let's heal the trees
-        Set<Link> remainingLinks = Sets.newHashSet();
-        Map<ConnectPoint, Set<Link>> candidateLinks = Maps.newHashMap();
+        Set<Link> notAffectedLinks = Sets.newHashSet();
+        Map<ConnectPoint, Set<Link>> affectedLinks = Maps.newHashMap();
         Map<ConnectPoint, Set<ConnectPoint>> candidateSinks = Maps.newHashMap();
-        Set<ConnectPoint> totalSources = Sets.newHashSet(candidateSources);
+        Set<ConnectPoint> totalSources = Sets.newHashSet(sourcesToBeRemoved);
         totalSources.addAll(remainingSources);
-        // Calculate all the links used by the sources
+        // Calculate all the links used by the sources and the current sinks
         totalSources.forEach(source -> {
             Set<ConnectPoint> currentSinks = sinks.values()
                     .stream().flatMap(Collection::stream)
                     .filter(sink -> isSinkForSource(mcastIp, sink, source))
                     .collect(Collectors.toSet());
             candidateSinks.put(source, currentSinks);
+            McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
+            Collection<? extends List<Link>> storedPaths = Versioned.valueOrElse(
+                    mcastPathStore.get(pathStoreKey), Lists.newArrayList());
             currentSinks.forEach(currentSink -> {
-                Optional<Path> currentPath = getPath(source.deviceId(), currentSink.deviceId(),
-                                                     mcastIp, null, source);
+                Optional<? extends List<Link>> currentPath = mcastUtils.getStoredPath(currentSink.deviceId(),
+                                                                                      storedPaths);
                 if (currentPath.isPresent()) {
-                    if (!candidateSources.contains(source)) {
-                        remainingLinks.addAll(currentPath.get().links());
+                    if (!sourcesToBeRemoved.contains(source)) {
+                        notAffectedLinks.addAll(currentPath.get());
                     } else {
-                        candidateLinks.put(source, Sets.newHashSet(currentPath.get().links()));
+                        affectedLinks.compute(source, (k, v) -> {
+                           v = v == null ? Sets.newHashSet() : v;
+                           v.addAll(currentPath.get());
+                           return v;
+                        });
                     }
                 }
             });
         });
         // Clean transit links
-        candidateLinks.forEach((source, currentCandidateLinks) -> {
-            Set<Link> linksToBeRemoved = Sets.difference(currentCandidateLinks, remainingLinks)
+        affectedLinks.forEach((source, currentCandidateLinks) -> {
+            Set<Link> linksToBeRemoved = Sets.difference(currentCandidateLinks, notAffectedLinks)
                     .immutableCopy();
             if (!linksToBeRemoved.isEmpty()) {
                 currentCandidateLinks.forEach(link -> {
@@ -391,8 +394,9 @@
             }
         });
         // Clean ingress and egress
-        candidateSources.forEach(source -> {
+        sourcesToBeRemoved.forEach(source -> {
             Set<ConnectPoint> currentSinks = candidateSinks.get(source);
+            McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
             currentSinks.forEach(currentSink -> {
                 VlanId assignedVlan = mcastUtils.assignedVlan(source.deviceId().equals(currentSink.deviceId()) ?
                                                                       source : null);
@@ -435,6 +439,8 @@
                 mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, currentSink.deviceId(),
                                                             source));
             });
+            // Clean the mcast paths
+            mcastPathStore.removeAll(pathStoreKey);
         });
     }
 
@@ -469,6 +475,7 @@
                     .stream().findFirst().orElse(null);
             Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT, source);
             Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS, source);
+            McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
             // If there are no egress and transit devices, sinks could be only on the ingress
             if (!egressDevices.isEmpty()) {
                 egressDevices.forEach(deviceId -> {
@@ -486,6 +493,8 @@
                 removeGroupFromDevice(ingressDevice, mcastIp, mcastUtils.assignedVlan(source));
                 mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, ingressDevice, source));
             }
+            // Clean the mcast paths
+            mcastPathStore.removeAll(pathStoreKey);
         });
         // Finally, withdraw the leadership
         mcastUtils.withdrawLeader(mcastIp);
@@ -509,21 +518,48 @@
             log.debug("Skip {} due to lack of leadership", mcastIp);
             return;
         }
-        Map<ConnectPoint, Map<ConnectPoint, Optional<Path>>> treesToBeRemoved = Maps.newHashMap();
+        Map<ConnectPoint, Map<ConnectPoint, Optional<? extends List<Link>>>> treesToBeRemoved = Maps.newHashMap();
         Map<ConnectPoint, Set<ConnectPoint>> treesToBeAdded = Maps.newHashMap();
+        Set<Link> goodLinks = Sets.newHashSet();
+        Map<ConnectPoint, Set<DeviceId>> goodDevicesBySource = Maps.newHashMap();
         sources.forEach(source -> {
             // Save the path associated to the sinks to be removed
-            Set<ConnectPoint> candidateSinks = processSinksToBeRemoved(mcastIp, prevSinks,
+            Set<ConnectPoint> sinksToBeRemoved = processSinksToBeRemoved(mcastIp, prevSinks,
                                                                          newSinks, source);
-            // Skip offline devices
-            Set<ConnectPoint> sinksToBeRemoved = candidateSinks.stream()
-                    .filter(sink -> srManager.deviceService.isAvailable(sink.deviceId()))
-                    .collect(Collectors.toSet());
-            Map<ConnectPoint, Optional<Path>> treeToBeRemoved = Maps.newHashMap();
-            sinksToBeRemoved.forEach(sink -> treeToBeRemoved.put(sink, getPath(source.deviceId(),
-                                                                               sink.deviceId(), mcastIp,
-                                                                               null, source)));
+            Map<ConnectPoint, Optional<? extends List<Link>>> treeToBeRemoved = Maps.newHashMap();
+            McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
+            Collection<? extends List<Link>> storedPaths = Versioned.valueOrElse(
+                    mcastPathStore.get(pathStoreKey), Lists.newArrayList());
+            sinksToBeRemoved.forEach(sink -> treeToBeRemoved.put(sink, mcastUtils.getStoredPath(sink.deviceId(),
+                                                                                                storedPaths)));
             treesToBeRemoved.put(source, treeToBeRemoved);
+            // Save the good links and good devices
+            Set<DeviceId> goodDevices = Sets.newHashSet();
+            Set<DeviceId> totalDevices = Sets.newHashSet(getDevice(mcastIp, EGRESS, source));
+            totalDevices.addAll(getDevice(mcastIp, INGRESS, source));
+            Set<ConnectPoint> notAffectedSinks = Sets.newHashSet();
+            // Compute good sinks
+            totalDevices.forEach(device -> {
+                Set<ConnectPoint> sinks = getSinks(mcastIp, device, source);
+                notAffectedSinks.addAll(Sets.difference(sinks, sinksToBeRemoved));
+            });
+            // Compute good paths and good devices
+            notAffectedSinks.forEach(notAffectedSink -> {
+                Optional<? extends List<Link>> notAffectedPath = mcastUtils.getStoredPath(notAffectedSink.deviceId(),
+                                                                                          storedPaths);
+                if (notAffectedPath.isPresent()) {
+                    List<Link> goodPath = notAffectedPath.get();
+                    goodLinks.addAll(goodPath);
+                    goodPath.forEach(link -> goodDevices.add(link.src().deviceId()));
+                } else {
+                    goodDevices.add(notAffectedSink.deviceId());
+                }
+            });
+            goodDevicesBySource.compute(source, (k, v) -> {
+                v = v == null ? Sets.newHashSet() : v;
+                v.addAll(goodDevices);
+                return v;
+            });
             // Recover the dual-homed sinks
             Set<ConnectPoint> sinksToBeRecovered = processSinksToBeRecovered(mcastIp, newSinks,
                                                                              prevSinks, source);
@@ -531,7 +567,8 @@
         });
         // Remove the sinks taking into account the multiple sources and the original paths
         treesToBeRemoved.forEach((source, tree) ->
-            tree.forEach((sink, path) -> processSinkRemovedInternal(source, sink, mcastIp, path)));
+            tree.forEach((sink, path) -> processSinkRemovedInternal(source, sink, mcastIp, path,
+                                                                    goodLinks, goodDevicesBySource.get(source))));
         // Add new sinks according to the recovery procedure
         treesToBeAdded.forEach((source, sinks) ->
             sinks.forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, null)));
@@ -544,9 +581,16 @@
      * @param sink connection point of the multicast sink
      * @param mcastIp multicast group IP address
      * @param mcastPath path associated to the sink
+     * @param usedLinks links used by the other sinks
+     * @param usedDevices devices used by other sinks
      */
     private void processSinkRemovedInternal(ConnectPoint source, ConnectPoint sink,
-                                            IpAddress mcastIp, Optional<Path> mcastPath) {
+                                            IpAddress mcastIp, Optional<? extends List<Link>> mcastPath,
+                                            Set<Link> usedLinks, Set<DeviceId> usedDevices) {
+
+        log.info("Used links {}", usedLinks);
+        log.info("Used devices {}", usedDevices);
+
         lastMcastChange.set(Instant.now());
         log.info("Processing sink removed {} for group {} and for source {}", sink, mcastIp, source);
         boolean isLast;
@@ -570,19 +614,24 @@
         }
         // If this is the last sink on the device, also update upstream
         if (mcastPath.isPresent()) {
-            List<Link> links = Lists.newArrayList(mcastPath.get().links());
-            Collections.reverse(links);
-            for (Link link : links) {
-                if (isLast) {
-                    isLast = removePortFromDevice(link.src().deviceId(), link.src().port(), mcastIp,
-                    mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
-                    if (isLast) {
+            List<Link> links = Lists.newArrayList(mcastPath.get());
+            if (isLast) {
+                // Clean the path
+                McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
+                mcastPathStore.remove(pathStoreKey, mcastPath.get());
+                Collections.reverse(links);
+                for (Link link : links) {
+                    // If nobody is using the port remove
+                    if (!usedLinks.contains(link)) {
+                        removePortFromDevice(link.src().deviceId(), link.src().port(), mcastIp,
+                          mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
+                    }
+                    // If nobody is using the device
+                    if (!usedDevices.contains(link.src().deviceId())) {
                         mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, link.src().deviceId(), source));
                     }
                 }
             }
-        } else {
-            log.warn("Unable to find a path from {} to {}. Abort sinkRemoved", source.deviceId(), sink.deviceId());
         }
     }
 
@@ -636,9 +685,10 @@
             return;
         }
         // Find a path. If present, create/update groups and flows for each hop
-        Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(), mcastIp, allPaths, source);
+        Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(), mcastIp, allPaths);
         if (mcastPath.isPresent()) {
             List<Link> links = mcastPath.get().links();
+            McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
             // Setup mcast role for ingress
             mcastRoleStore.put(new McastRoleStoreKey(mcastIp, source.deviceId(), source), INGRESS);
             // Setup properly the transit forwarding
@@ -653,15 +703,14 @@
             // Setup mcast role for the transit
             links.stream()
                     .filter(link -> !link.dst().deviceId().equals(sink.deviceId()))
-                    .forEach(link -> {
-                        log.trace("Transit links {}", link);
-                        mcastRoleStore.put(new McastRoleStoreKey(mcastIp, link.dst().deviceId(),
-                                source), TRANSIT);
-                    });
+                    .forEach(link -> mcastRoleStore.put(new McastRoleStoreKey(mcastIp, link.dst().deviceId(),
+                                                                              source), TRANSIT));
             // Process the egress device
             addPortToDevice(sink.deviceId(), sink.port(), mcastIp, mcastUtils.assignedVlan(null));
             // Setup mcast role for egress
             mcastRoleStore.put(new McastRoleStoreKey(mcastIp, sink.deviceId(), source), EGRESS);
+            // Store the used path
+            mcastPathStore.put(pathStoreKey, links);
         } else {
             log.warn("Unable to find a path from {} to {}. Abort sinkAdded", source.deviceId(), sink.deviceId());
         }
@@ -697,11 +746,11 @@
     }
 
     private void processLinkDownInternal(Link linkDown) {
-        lastMcastChange.set(Instant.now());
         // Get mcast groups affected by the link going down
         Set<IpAddress> affectedGroups = getAffectedGroups(linkDown);
         log.info("Processing link down {} for groups {}", linkDown, affectedGroups);
         affectedGroups.forEach(mcastIp -> {
+            lastMcastChange.set(Instant.now());
             log.debug("Processing link down {} for group {}", linkDown, mcastIp);
             recoverFailure(mcastIp, linkDown);
         });
@@ -717,12 +766,12 @@
     }
 
     private void processDeviceDownInternal(DeviceId deviceDown) {
-        lastMcastChange.set(Instant.now());
         // Get the mcast groups affected by the device going down
         Set<IpAddress> affectedGroups = getAffectedGroups(deviceDown);
         log.info("Processing device down {} for groups {}", deviceDown, affectedGroups);
         updateFilterObjStoreByDevice(deviceDown);
         affectedGroups.forEach(mcastIp -> {
+            lastMcastChange.set(Instant.now());
             log.debug("Processing device down {} for group {}", deviceDown, mcastIp);
             recoverFailure(mcastIp, deviceDown);
         });
@@ -735,151 +784,249 @@
      * @param failedElement the failed element
      */
     private void recoverFailure(IpAddress mcastIp, Object failedElement) {
-        // TODO Optimize when the group editing is in place
+        // Do not proceed if we are not the leaders
         if (!mcastUtils.isLeader(mcastIp)) {
             log.debug("Skip {} due to lack of leadership", mcastIp);
             return;
         }
+        // Skip if it is not an infra failure
+        Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
+        if (!mcastUtils.isInfraFailure(transitDevices, failedElement)) {
+            log.debug("Skip {} not an infrastructure failure", mcastIp);
+            return;
+        }
         // Do not proceed if the sources of this group are missing
         Set<ConnectPoint> sources = getSources(mcastIp);
         if (sources.isEmpty()) {
             log.warn("Missing sources for group {}", mcastIp);
             return;
         }
-        // Find out the ingress devices of the affected group
-        // If sinks are in other leafs, we have ingress, transit, egress, and source
-        // If sinks are in the same leaf, we have just ingress and source
-        Set<DeviceId> ingressDevices = getDevice(mcastIp, INGRESS);
-        if (ingressDevices.isEmpty()) {
-            log.warn("Missing ingress devices for group {}", mcastIp);
-            return;
-        }
-        // For each tree, delete ingress-transit part
-        sources.forEach(source -> {
-            Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT, source);
-            transitDevices.forEach(transitDevice -> {
-                removeGroupFromDevice(transitDevice, mcastIp, mcastUtils.assignedVlan(null));
-                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, transitDevice, source));
+        // Get all the paths, affected paths, good links and good devices
+        Set<List<Link>> storedPaths = getStoredPaths(mcastIp);
+        Set<List<Link>> affectedPaths = mcastUtils.getAffectedPaths(storedPaths, failedElement);
+        Set<Link> goodLinks = Sets.newHashSet();
+        Map<DeviceId, Set<DeviceId>> goodDevicesBySource = Maps.newHashMap();
+        Map<DeviceId, Set<ConnectPoint>> processedSourcesByEgress = Maps.newHashMap();
+        Sets.difference(storedPaths, affectedPaths).forEach(goodPath -> {
+            goodLinks.addAll(goodPath);
+            DeviceId srcDevice = goodPath.get(0).src().deviceId();
+            Set<DeviceId> goodDevices = Sets.newHashSet();
+            goodPath.forEach(link -> goodDevices.add(link.src().deviceId()));
+            goodDevicesBySource.compute(srcDevice, (k, v) -> {
+                v = v == null ? Sets.newHashSet() : v;
+                v.addAll(goodDevices);
+                return v;
             });
         });
-        removeIngressTransitPorts(mcastIp, ingressDevices, sources);
-        // TODO Evaluate the possibility of building optimize trees between sources
-        Map<DeviceId, Set<ConnectPoint>> notRecovered = Maps.newHashMap();
-        sources.forEach(source -> {
-            Set<DeviceId> notRecoveredInternal = Sets.newHashSet();
-            DeviceId ingressDevice = ingressDevices.stream()
-                    .filter(deviceId -> deviceId.equals(source.deviceId())).findFirst().orElse(null);
-            // Clean also the ingress
-            if (failedElement instanceof DeviceId && ingressDevice.equals(failedElement)) {
-                removeGroupFromDevice((DeviceId) failedElement, mcastIp, mcastUtils.assignedVlan(source));
-                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, (DeviceId) failedElement, source));
+        affectedPaths.forEach(affectedPath -> {
+            // TODO remove
+            log.info("Good links {}", goodLinks);
+            // TODO remove
+            log.info("Good devices {}", goodDevicesBySource);
+            // TODO trace
+            log.info("Healing the path {}", affectedPath);
+            DeviceId srcDevice = affectedPath.get(0).src().deviceId();
+            DeviceId dstDevice = affectedPath.get(affectedPath.size() - 1).dst().deviceId();
+            // Fix in one shot multiple sources
+            Set<ConnectPoint> affectedSources = sources.stream()
+                    .filter(device -> device.deviceId().equals(srcDevice))
+                    .collect(Collectors.toSet());
+            Set<ConnectPoint> processedSources = processedSourcesByEgress.getOrDefault(dstDevice,
+                                                                                       Collections.emptySet());
+            Optional<Path> alternativePath = getPath(srcDevice, dstDevice, mcastIp, null);
+            // If an alternative is possible go ahead
+            if (alternativePath.isPresent()) {
+                // TODO trace
+                log.info("Alternative path {}", alternativePath.get().links());
+            } else {
+                // Otherwise try to come up with an alternative
+                // TODO trace
+                log.info("No alternative path");
+                Set<ConnectPoint> notAffectedSources = Sets.difference(sources, affectedSources);
+                Set<ConnectPoint> remainingSources = Sets.difference(notAffectedSources, processedSources);
+                alternativePath = recoverSinks(dstDevice, mcastIp, affectedSources, remainingSources);
+                processedSourcesByEgress.compute(dstDevice, (k, v) -> {
+                    v = v == null ? Sets.newHashSet() : v;
+                    v.addAll(affectedSources);
+                    return v;
+                });
             }
-            if (ingressDevice == null) {
-                log.warn("Skip failure recovery - " +
-                                 "Missing ingress for source {} and group {}", source, mcastIp);
-                return;
-            }
-            Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS, source);
-            Map<DeviceId, List<Path>> mcastTree = computeMcastTree(mcastIp, ingressDevice, egressDevices);
-            // We have to verify, if there are egresses without paths
-            mcastTree.forEach((egressDevice, paths) -> {
-                Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                   mcastIp, paths, source);
-                // No paths, we have to try with alternative location
-                if (!mcastPath.isPresent()) {
-                    notRecovered.compute(egressDevice, (deviceId, listSources) -> {
-                        listSources = listSources == null ? Sets.newHashSet() : listSources;
-                        listSources.add(source);
-                        return listSources;
-                    });
-                    notRecoveredInternal.add(egressDevice);
-                }
-            });
-            // Fast path, we can recover all the locations
-            if (notRecoveredInternal.isEmpty()) {
-                mcastTree.forEach((egressDevice, paths) -> {
-                    Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                       mcastIp, paths, source);
-                    if (mcastPath.isPresent()) {
-                        installPath(mcastIp, source, mcastPath.get());
+            // Recover from the failure if possible
+            Optional<Path> finalPath = alternativePath;
+            affectedSources.forEach(affectedSource -> {
+                // Update the mcastPath store
+                McastPathStoreKey mcastPathStoreKey = new McastPathStoreKey(mcastIp, affectedSource);
+                // Verify if there are local sinks
+                Set<DeviceId> localSinks = getSinks(mcastIp, srcDevice, affectedSource).stream()
+                        .map(ConnectPoint::deviceId)
+                        .collect(Collectors.toSet());
+                Set<DeviceId> goodDevices = goodDevicesBySource.compute(affectedSource.deviceId(), (k, v) -> {
+                    v = v == null ? Sets.newHashSet() : v;
+                    v.addAll(localSinks);
+                    return v;
+                });
+                // TODO remove
+                log.info("Good devices {}", goodDevicesBySource);
+                Collection<? extends List<Link>> storedPathsBySource = Versioned.valueOrElse(
+                        mcastPathStore.get(mcastPathStoreKey), Lists.newArrayList());
+                Optional<? extends List<Link>> storedPath = storedPathsBySource.stream()
+                        .filter(path -> path.equals(affectedPath))
+                        .findFirst();
+                // Remove bad links
+                affectedPath.forEach(affectedLink -> {
+                    DeviceId affectedDevice = affectedLink.src().deviceId();
+                    // If there is overlap with good paths - skip it
+                    if (!goodLinks.contains(affectedLink)) {
+                        removePortFromDevice(affectedDevice, affectedLink.src().port(), mcastIp,
+                            mcastUtils.assignedVlan(affectedDevice.equals(affectedSource.deviceId()) ?
+                                                            affectedSource : null));
+                    }
+                    // Remove role on the affected links if last
+                    if (!goodDevices.contains(affectedDevice)) {
+                        mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, affectedDevice, affectedSource));
                     }
                 });
-            } else {
-                // Let's try to recover using alternative locations
-                recoverSinks(egressDevices, notRecoveredInternal, mcastIp,
-                             ingressDevice, source);
-            }
-        });
-        // Finally remove the egresses not recovered
-        notRecovered.forEach((egressDevice, listSources) -> {
-            Set<ConnectPoint> currentSources = getSources(mcastIp, egressDevice, EGRESS);
-            if (Objects.equal(currentSources, listSources)) {
-                log.warn("Fail to recover egress device {} from {} failure {}",
-                         egressDevice, failedElement instanceof Link ? "Link" : "Device", failedElement);
-                removeGroupFromDevice(egressDevice, mcastIp, mcastUtils.assignedVlan(null));
-            }
-            listSources.forEach(source -> mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, egressDevice, source)));
+                // Sometimes the removal fails for serialization issue
+                // trying with the original object as workaround
+                if (storedPath.isPresent()) {
+                    mcastPathStore.remove(mcastPathStoreKey, storedPath.get());
+                } else {
+                    log.warn("Unable to find the corresponding path - trying removeal");
+                    mcastPathStore.remove(mcastPathStoreKey, affectedPath);
+                }
+                // Program new links
+                if (finalPath.isPresent()) {
+                    List<Link> links = finalPath.get().links();
+                    installPath(mcastIp, affectedSource, links);
+                    mcastPathStore.put(mcastPathStoreKey, links);
+                    links.forEach(link -> goodDevices.add(link.src().deviceId()));
+                    goodDevicesBySource.compute(srcDevice, (k, v) -> {
+                        v = v == null ? Sets.newHashSet() : v;
+                        v.addAll(goodDevices);
+                        return v;
+                    });
+                    goodLinks.addAll(finalPath.get().links());
+                }
+            });
         });
     }
 
     /**
-     * Try to recover sinks using alternate locations.
+     * Try to recover sinks using alternative locations.
      *
-     * @param egressDevices the original egress devices
-     * @param notRecovered the devices not recovered
+     * @param notRecovered the device not recovered
      * @param mcastIp the group address
-     * @param ingressDevice the ingress device
-     * @param source the source connect point
+     * @param affectedSources affected sources
+     * @param goodSources sources not affected
      */
-    private void recoverSinks(Set<DeviceId> egressDevices, Set<DeviceId> notRecovered,
-                              IpAddress mcastIp, DeviceId ingressDevice, ConnectPoint source) {
-        log.debug("Processing recover sinks for group {} and for source {}",
-                  mcastIp, source);
-        Set<DeviceId> recovered = Sets.difference(egressDevices, notRecovered);
-        Set<ConnectPoint> totalAffectedSinks = Sets.newHashSet();
-        Set<ConnectPoint> totalSinks = Sets.newHashSet();
-        // Let's compute all the affected sinks and all the sinks
-        notRecovered.forEach(deviceId -> {
-            totalAffectedSinks.addAll(
-                    mcastUtils.getAffectedSinks(deviceId, mcastIp).values().stream()
-                            .flatMap(Collection::stream)
-                            .filter(connectPoint -> connectPoint.deviceId().equals(deviceId))
-                            .collect(Collectors.toSet())
-            );
-            totalSinks.addAll(
-                    mcastUtils.getAffectedSinks(deviceId, mcastIp).values().stream()
-                            .flatMap(Collection::stream).collect(Collectors.toSet())
-            );
+    private Optional<Path> recoverSinks(DeviceId notRecovered, IpAddress mcastIp,
+                                    Set<ConnectPoint> affectedSources,
+                                    Set<ConnectPoint> goodSources) {
+        log.debug("Processing recover sinks on {} for group {}", notRecovered, mcastIp);
+        Map<ConnectPoint, Set<ConnectPoint>> affectedSinksBySource = Maps.newHashMap();
+        Map<ConnectPoint, Set<ConnectPoint>> sinksBySource = Maps.newHashMap();
+        Set<ConnectPoint> sources = Sets.union(affectedSources, goodSources);
+        // Hosts influenced by the failure
+        Map<HostId, Set<ConnectPoint>> hostIdSetMap = mcastUtils.getAffectedSinks(notRecovered, mcastIp);
+        // Locations influenced by the failure
+        Set<ConnectPoint> affectedSinks = hostIdSetMap.values()
+                .stream()
+                .flatMap(Collection::stream)
+                .filter(connectPoint -> connectPoint.deviceId().equals(notRecovered))
+                .collect(Collectors.toSet());
+        // All locations
+        Set<ConnectPoint> sinks = hostIdSetMap.values()
+                .stream()
+                .flatMap(Collection::stream)
+                .collect(Collectors.toSet());
+        // Maps sinks with the sources
+        sources.forEach(source -> {
+            Set<ConnectPoint> currentSinks = affectedSinks.stream()
+                    .filter(sink -> isSinkForSource(mcastIp, sink, source))
+                    .collect(Collectors.toSet());
+            affectedSinksBySource.put(source, currentSinks);
         });
-        Set<ConnectPoint> sinksToBeAdded = Sets.difference(totalSinks, totalAffectedSinks);
-        Set<DeviceId> newEgressDevices = sinksToBeAdded.stream()
-                .map(ConnectPoint::deviceId).collect(Collectors.toSet());
-        newEgressDevices.addAll(recovered);
-        Set<DeviceId> copyNewEgressDevices = ImmutableSet.copyOf(newEgressDevices);
-        newEgressDevices = newEgressDevices.stream()
-                .filter(deviceId -> !deviceId.equals(ingressDevice)).collect(Collectors.toSet());
-        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(mcastIp, ingressDevice, newEgressDevices);
-        // if the source was originally in the new locations, add new sinks
-        if (copyNewEgressDevices.contains(ingressDevice)) {
-            sinksToBeAdded.stream()
-                    .filter(connectPoint -> connectPoint.deviceId().equals(ingressDevice))
-                    .forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, ImmutableList.of()));
-        }
-        // Construct a new path for each egress device
-        mcastTree.forEach((egressDevice, paths) -> {
-            Optional<Path> mcastPath = getPath(ingressDevice, egressDevice, mcastIp, paths, source);
-            if (mcastPath.isPresent()) {
-                // Using recovery procedure
-                if (recovered.contains(egressDevice)) {
-                    installPath(mcastIp, source, mcastPath.get());
-                } else {
-                    // otherwise we need to threat as new sink
-                    sinksToBeAdded.stream()
-                            .filter(connectPoint -> connectPoint.deviceId().equals(egressDevice))
-                            .forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, paths));
+        // Remove sinks one by one if they are not used by other sources
+        affectedSources.forEach(affectedSource -> {
+            Set<ConnectPoint> currentSinks = affectedSinksBySource.get(affectedSource);
+            log.info("Current sinks {} for source {}", currentSinks, affectedSource);
+            currentSinks.forEach(currentSink -> {
+                VlanId assignedVlan = mcastUtils.assignedVlan(
+                        affectedSource.deviceId().equals(currentSink.deviceId()) ? affectedSource : null);
+                log.info("Assigned vlan {}", assignedVlan);
+                Set<VlanId> otherVlans = goodSources.stream()
+                        .filter(remainingSource -> affectedSinksBySource.get(remainingSource).contains(currentSink))
+                        .map(remainingSource -> mcastUtils.assignedVlan(
+                                remainingSource.deviceId().equals(currentSink.deviceId()) ? remainingSource : null))
+                        .collect(Collectors.toSet());
+                log.info("Other vlans {}", otherVlans);
+                // Sinks on other leaves
+                if (!otherVlans.contains(assignedVlan)) {
+                    removePortFromDevice(currentSink.deviceId(), currentSink.port(), mcastIp, assignedVlan);
                 }
-            }
+                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, currentSink.deviceId(), affectedSource));
+            });
         });
+        // Get the sinks to be added and the new egress
+        Set<DeviceId> newEgress = Sets.newHashSet();
+        affectedSources.forEach(affectedSource -> {
+            Set<ConnectPoint> currentSinks = affectedSinksBySource.get(affectedSource);
+            Set<ConnectPoint> newSinks = Sets.difference(sinks, currentSinks);
+            sinksBySource.put(affectedSource, newSinks);
+            newSinks.stream()
+                    .map(ConnectPoint::deviceId)
+                    .forEach(newEgress::add);
+        });
+        log.info("newEgress {}", newEgress);
+        // If there are more than one new egresses, return the problem
+        if (newEgress.size() != 1) {
+            log.warn("There are {} new egress, wrong configuration. Abort.", newEgress.size());
+            return Optional.empty();
+        }
+        DeviceId egress = newEgress.stream()
+                .findFirst()
+                .orElse(null);
+        DeviceId ingress = affectedSources.stream()
+                .map(ConnectPoint::deviceId)
+                .findFirst()
+                .orElse(null);
+        log.info("Ingress {}", ingress);
+        if (ingress == null) {
+            log.warn("No new ingress, wrong configuration. Abort.");
+            return Optional.empty();
+        }
+        // Get an alternative path
+        Optional<Path> alternativePath = getPath(ingress, egress, mcastIp, null);
+        // If there are new path install sinks and return path
+        if (alternativePath.isPresent()) {
+            log.info("Alternative path {}", alternativePath.get().links());
+            affectedSources.forEach(affectedSource -> {
+                Set<ConnectPoint> newSinks = sinksBySource.get(affectedSource);
+                newSinks.forEach(newSink -> {
+                    addPortToDevice(newSink.deviceId(), newSink.port(), mcastIp, mcastUtils.assignedVlan(null));
+                    mcastRoleStore.put(new McastRoleStoreKey(mcastIp, newSink.deviceId(), affectedSource), EGRESS);
+                });
+            });
+            return alternativePath;
+        }
+        // No new path but sinks co-located with sources install sinks and return empty
+        if (ingress.equals(egress)) {
+            log.info("No Alternative path but sinks co-located");
+            affectedSources.forEach(affectedSource -> {
+                Set<ConnectPoint> newSinks = sinksBySource.get(affectedSource);
+                newSinks.forEach(newSink -> {
+                    if (affectedSource.port().equals(newSink.port())) {
+                        log.warn("Skip {} since sink {} is on the same port of source {}. Abort",
+                                 mcastIp, newSink, affectedSource);
+                        return;
+                    }
+                    addPortToDevice(newSink.deviceId(), newSink.port(), mcastIp,
+                                    mcastUtils.assignedVlan(affectedSource));
+                    mcastRoleStore.put(new McastRoleStoreKey(mcastIp, newSink.deviceId(), affectedSource), INGRESS);
+                });
+            });
+        }
+        return Optional.empty();
     }
 
     /**
@@ -1057,45 +1204,6 @@
     }
 
     /**
-     * Utility method to remove all the ingress transit ports.
-     *
-     * @param mcastIp the group ip
-     * @param ingressDevices the ingress devices
-     * @param sources the source connect points
-     */
-    private void removeIngressTransitPorts(IpAddress mcastIp, Set<DeviceId> ingressDevices,
-                                           Set<ConnectPoint> sources) {
-        Map<ConnectPoint, Set<PortNumber>> ingressTransitPorts = Maps.newHashMap();
-        sources.forEach(source -> {
-            DeviceId ingressDevice = ingressDevices.stream()
-                    .filter(deviceId -> deviceId.equals(source.deviceId()))
-                    .findFirst().orElse(null);
-            if (ingressDevice == null) {
-                log.warn("Skip removeIngressTransitPorts - " +
-                                 "Missing ingress for source {} and group {}",
-                         source, mcastIp);
-                return;
-            }
-            Set<PortNumber> ingressTransitPort = ingressTransitPort(mcastIp, ingressDevice, source);
-            if (ingressTransitPort.isEmpty()) {
-                log.warn("No transit ports to remove on device {}", ingressDevice);
-                return;
-            }
-            ingressTransitPorts.put(source, ingressTransitPort);
-        });
-        ingressTransitPorts.forEach((source, ports) -> ports.forEach(ingressTransitPort -> {
-            DeviceId ingressDevice = ingressDevices.stream()
-                    .filter(deviceId -> deviceId.equals(source.deviceId()))
-                    .findFirst().orElse(null);
-            boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
-                                                  mcastIp, mcastUtils.assignedVlan(source));
-            if (isLast) {
-                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, ingressDevice, source));
-            }
-        }));
-    }
-
-    /**
      * Adds a port to given multicast group on given device. This involves the
      * update of L3 multicast group and multicast routing table entry.
      *
@@ -1106,6 +1214,8 @@
      */
     private void addPortToDevice(DeviceId deviceId, PortNumber port,
                                  IpAddress mcastIp, VlanId assignedVlan) {
+        // TODO trace
+        log.info("Adding {} on {}/{} and vlan {}", mcastIp, deviceId, port, assignedVlan);
         McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId, assignedVlan);
         ImmutableSet.Builder<PortNumber> portBuilder = ImmutableSet.builder();
         NextObjective newNextObj;
@@ -1184,6 +1294,8 @@
      */
     private boolean removePortFromDevice(DeviceId deviceId, PortNumber port,
                                          IpAddress mcastIp, VlanId assignedVlan) {
+        // TODO trace
+        log.info("Removing {} on {}/{} and vlan {}", mcastIp, deviceId, port, assignedVlan);
         McastStoreKey mcastStoreKey =
                 new McastStoreKey(mcastIp, deviceId, assignedVlan);
         // This device is not serving this multicast group
@@ -1247,6 +1359,8 @@
      */
     private void removeGroupFromDevice(DeviceId deviceId, IpAddress mcastIp,
                                        VlanId assignedVlan) {
+        // TODO trace
+        log.info("Removing {} on {} and vlan {}", mcastIp, deviceId, assignedVlan);
         McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId, assignedVlan);
         // This device is not serving this multicast group
         if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
@@ -1268,8 +1382,7 @@
         mcastNextObjStore.remove(mcastStoreKey);
     }
 
-    private void installPath(IpAddress mcastIp, ConnectPoint source, Path mcastPath) {
-        List<Link> links = mcastPath.links();
+    private void installPath(IpAddress mcastIp, ConnectPoint source, List<Link> links) {
         if (links.isEmpty()) {
             log.warn("There is no link that can be used. Stopping installation.");
             return;
@@ -1281,7 +1394,7 @@
         // and a new filter objective on the destination port
         links.forEach(link -> {
             addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp,
-                            mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
+                mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
             McastFilteringObjStoreKey mcastFilterObjStoreKey = new McastFilteringObjStoreKey(link.dst(),
                     mcastUtils.assignedVlan(null), mcastIp.isIp4());
             addFilterToDevice(mcastFilterObjStoreKey, mcastIp, null);
@@ -1294,145 +1407,6 @@
     }
 
     /**
-     * Go through all the paths, looking for shared links to be used
-     * in the final path computation.
-     *
-     * @param egresses egress devices
-     * @param availablePaths all the available paths towards the egress
-     * @return shared links between egress devices
-     */
-    private Set<Link> exploreMcastTree(Set<DeviceId> egresses,
-                                       Map<DeviceId, List<Path>> availablePaths) {
-        int minLength = Integer.MAX_VALUE;
-        int length;
-        List<Path> currentPaths;
-        // Verify the source can still reach all the egresses
-        for (DeviceId egress : egresses) {
-            // From the source we cannot reach all the sinks
-            // just continue and let's figure out after
-            currentPaths = availablePaths.get(egress);
-            if (currentPaths.isEmpty()) {
-                continue;
-            }
-            // Get the length of the first one available, update the min length
-            length = currentPaths.get(0).links().size();
-            if (length < minLength) {
-                minLength = length;
-            }
-        }
-        // If there are no paths
-        if (minLength == Integer.MAX_VALUE) {
-            return Collections.emptySet();
-        }
-        int index = 0;
-        Set<Link> sharedLinks = Sets.newHashSet();
-        Set<Link> currentSharedLinks;
-        Set<Link> currentLinks;
-        DeviceId egressToRemove = null;
-        // Let's find out the shared links
-        while (index < minLength) {
-            // Initialize the intersection with the paths related to the first egress
-            currentPaths = availablePaths.get(egresses.stream().findFirst().orElse(null));
-            currentSharedLinks = Sets.newHashSet();
-            // Iterate over the paths and take the "index" links
-            for (Path path : currentPaths) {
-                currentSharedLinks.add(path.links().get(index));
-            }
-            // Iterate over the remaining egress
-            for (DeviceId egress : egresses) {
-                // Iterate over the paths and take the "index" links
-                currentLinks = Sets.newHashSet();
-                for (Path path : availablePaths.get(egress)) {
-                    currentLinks.add(path.links().get(index));
-                }
-                // Do intersection
-                currentSharedLinks = Sets.intersection(currentSharedLinks, currentLinks);
-                // If there are no shared paths exit and record the device to remove
-                // we have to retry with a subset of sinks
-                if (currentSharedLinks.isEmpty()) {
-                    egressToRemove = egress;
-                    index = minLength;
-                    break;
-                }
-            }
-            sharedLinks.addAll(currentSharedLinks);
-            index++;
-        }
-        // If the shared links is empty and there are egress let's retry another time with less sinks,
-        // we can still build optimal subtrees
-        if (sharedLinks.isEmpty() && egresses.size() > 1 && egressToRemove != null) {
-            egresses.remove(egressToRemove);
-            sharedLinks = exploreMcastTree(egresses, availablePaths);
-        }
-        return sharedLinks;
-    }
-
-    /**
-     * Build Mcast tree having as root the given source and as leaves the given egress points.
-     *
-     * @param mcastIp multicast group
-     * @param source source of the tree
-     * @param sinks leaves of the tree
-     * @return the computed Mcast tree
-     */
-    private Map<ConnectPoint, List<Path>> computeSinkMcastTree(IpAddress mcastIp,
-                                                               DeviceId source,
-                                                               Set<ConnectPoint> sinks) {
-        // Get the egress devices, remove source from the egress if present
-        Set<DeviceId> egresses = sinks.stream().map(ConnectPoint::deviceId)
-                .filter(deviceId -> !deviceId.equals(source)).collect(Collectors.toSet());
-        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(mcastIp, source, egresses);
-        final Map<ConnectPoint, List<Path>> finalTree = Maps.newHashMap();
-        // We need to put back the source if it was originally present
-        sinks.forEach(sink -> {
-            List<Path> sinkPaths = mcastTree.get(sink.deviceId());
-            finalTree.put(sink, sinkPaths != null ? sinkPaths : ImmutableList.of());
-        });
-        return finalTree;
-    }
-
-    /**
-     * Build Mcast tree having as root the given source and as leaves the given egress.
-     *
-     * @param mcastIp multicast group
-     * @param source source of the tree
-     * @param egresses leaves of the tree
-     * @return the computed Mcast tree
-     */
-    private Map<DeviceId, List<Path>> computeMcastTree(IpAddress mcastIp,
-                                                       DeviceId source,
-                                                       Set<DeviceId> egresses) {
-        log.debug("Computing tree for Multicast group {}, source {} and leafs {}",
-                  mcastIp, source, egresses);
-        // Pre-compute all the paths
-        Map<DeviceId, List<Path>> availablePaths = Maps.newHashMap();
-        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
-                                                                       Collections.emptySet())));
-        // Explore the topology looking for shared links amongst the egresses
-        Set<Link> linksToEnforce = exploreMcastTree(Sets.newHashSet(egresses), availablePaths);
-        // Build the final paths enforcing the shared links between egress devices
-        availablePaths.clear();
-        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
-                                                                       linksToEnforce)));
-        return availablePaths;
-    }
-
-    /**
-     * Gets path from src to dst computed using the custom link weigher.
-     *
-     * @param src source device ID
-     * @param dst destination device ID
-     * @return list of paths from src to dst
-     */
-    private List<Path> getPaths(DeviceId src, DeviceId dst, Set<Link> linksToEnforce) {
-        final Topology currentTopology = topologyService.currentTopology();
-        final LinkWeigher linkWeigher = new SRLinkWeigher(srManager, src, linksToEnforce);
-        List<Path> allPaths = Lists.newArrayList(topologyService.getPaths(currentTopology, src, dst, linkWeigher));
-        log.trace("{} path(s) found from {} to {}", allPaths.size(), src, dst);
-        return allPaths;
-    }
-
-    /**
      * Gets a path from src to dst.
      * If a path was allocated before, returns the allocated path.
      * Otherwise, randomly pick one from available paths.
@@ -1441,52 +1415,37 @@
      * @param dst destination device ID
      * @param mcastIp multicast group
      * @param allPaths paths list
+     *
      * @return an optional path from src to dst
      */
-    private Optional<Path> getPath(DeviceId src, DeviceId dst, IpAddress mcastIp,
-                                   List<Path> allPaths, ConnectPoint source) {
+    private Optional<Path> getPath(DeviceId src, DeviceId dst,
+                                   IpAddress mcastIp, List<Path> allPaths) {
         if (allPaths == null) {
-            allPaths = getPaths(src, dst, Collections.emptySet());
+            allPaths = mcastUtils.getPaths(src, dst, Collections.emptySet());
         }
         if (allPaths.isEmpty()) {
             return Optional.empty();
         }
         // Create a map index of suitability-to-list of paths. For example
-        // a path in the list associated to the index 1 shares only the
-        // first hop and it is less suitable of a path belonging to the index
-        // 2 that shares leaf-spine.
+        // a path in the list associated to the index 1 shares only one link
+        // and it is less suitable of a path belonging to the index 2
         Map<Integer, List<Path>> eligiblePaths = Maps.newHashMap();
-        int nhop;
-        McastStoreKey mcastStoreKey;
-        PortNumber srcPort;
-        Set<PortNumber> existingPorts;
-        NextObjective nextObj;
+        int score;
+        // Let's build the multicast tree
+        Set<List<Link>> storedPaths = getStoredPaths(mcastIp);
+        Set<Link> storedTree = storedPaths.stream()
+                .flatMap(Collection::stream).collect(Collectors.toSet());
+        log.trace("Stored tree {}", storedTree);
+        Set<Link> pathLinks;
         for (Path path : allPaths) {
             if (!src.equals(path.links().get(0).src().deviceId())) {
                 continue;
             }
-            nhop = 0;
-            // Iterate over the links
-            for (Link hop : path.links()) {
-                VlanId assignedVlan = mcastUtils.assignedVlan(hop.src().deviceId().equals(src) ?
-                                                                      source : null);
-                mcastStoreKey = new McastStoreKey(mcastIp, hop.src().deviceId(), assignedVlan);
-                // It does not exist in the store, go to the next link
-                if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
-                    continue;
-                }
-                nextObj = mcastNextObjStore.get(mcastStoreKey).value();
-                existingPorts = mcastUtils.getPorts(nextObj.next());
-                srcPort = hop.src().port();
-                // the src port is not used as output, go to the next link
-                if (!existingPorts.contains(srcPort)) {
-                    continue;
-                }
-                nhop++;
-            }
-            // n_hop defines the index
-            if (nhop > 0) {
-                eligiblePaths.compute(nhop, (index, paths) -> {
+            pathLinks = Sets.newHashSet(path.links());
+            score = Sets.intersection(pathLinks, storedTree).size();
+            // score defines the index
+            if (score > 0) {
+                eligiblePaths.compute(score, (index, paths) -> {
                     paths = paths == null ? Lists.newArrayList() : paths;
                     paths.add(path);
                     return paths;
@@ -1509,6 +1468,19 @@
     }
 
     /**
+     * Gets stored paths of the group.
+     *
+     * @param mcastIp group address
+     * @return a collection of paths
+     */
+    private Set<List<Link>> getStoredPaths(IpAddress mcastIp) {
+        return mcastPathStore.stream()
+                .filter(entry -> entry.getKey().mcastIp().equals(mcastIp))
+                .map(Entry::getValue)
+                .collect(Collectors.toSet());
+    }
+
+    /**
      * Gets device(s) of given role and of given source in given multicast tree.
      *
      * @param mcastIp multicast IP
@@ -1539,21 +1511,6 @@
     }
 
     /**
-     * Gets source(s) of given role, given device in given multicast group.
-     *
-     * @param mcastIp multicast IP
-     * @param deviceId device id
-     * @param role multicast role
-     * @return set of device ID or empty set if not found
-     */
-    private Set<ConnectPoint> getSources(IpAddress mcastIp, DeviceId deviceId, McastRole role) {
-        return mcastRoleStore.entrySet().stream()
-                .filter(entry -> entry.getKey().mcastIp().equals(mcastIp) &&
-                        entry.getKey().deviceId().equals(deviceId) && entry.getValue().value() == role)
-                .map(Entry::getKey).map(McastRoleStoreKey::source).collect(Collectors.toSet());
-    }
-
-    /**
      * Gets source(s) of given multicast group.
      *
      * @param mcastIp multicast IP
@@ -1566,6 +1523,32 @@
     }
 
     /**
+     * Gets sink(s) of given multicast group.
+     *
+     * @param mcastIp multicast IP
+     * @return set of connect point or empty set if not found
+     */
+    private Set<ConnectPoint> getSinks(IpAddress mcastIp, DeviceId device, ConnectPoint source) {
+        McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
+        Collection<? extends List<Link>> storedPaths = Versioned.valueOrElse(
+                mcastPathStore.get(pathStoreKey), Lists.newArrayList());
+        VlanId assignedVlan = mcastUtils.assignedVlan(device.equals(source.deviceId()) ? source : null);
+        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, device, assignedVlan);
+        NextObjective nextObjective = Versioned.valueOrNull(mcastNextObjStore.get(mcastStoreKey));
+        ImmutableSet.Builder<ConnectPoint> cpBuilder = ImmutableSet.builder();
+        if (nextObjective != null) {
+            Set<PortNumber> outputPorts = mcastUtils.getPorts(nextObjective.next());
+            outputPorts.forEach(portNumber -> cpBuilder.add(new ConnectPoint(device, portNumber)));
+        }
+        Set<ConnectPoint> egressCp = cpBuilder.build();
+        return egressCp.stream()
+                .filter(connectPoint -> !mcastUtils.isInfraPort(connectPoint, storedPaths))
+                .collect(Collectors.toSet());
+    }
+
+
+
+    /**
      * Gets groups which is affected by the link down event.
      *
      * @param link link going down
@@ -1594,41 +1577,6 @@
     }
 
     /**
-     * Gets the spine-facing port on ingress device of given multicast group.
-     *
-     * @param mcastIp multicast IP
-     * @param ingressDevice the ingress device
-     * @param source the source connect point
-     * @return spine-facing port on ingress device
-     */
-    private Set<PortNumber> ingressTransitPort(IpAddress mcastIp, DeviceId ingressDevice,
-                                               ConnectPoint source) {
-        ImmutableSet.Builder<PortNumber> portBuilder = ImmutableSet.builder();
-        if (ingressDevice != null) {
-            Versioned<NextObjective> nextObjVers = mcastNextObjStore.get(new McastStoreKey(mcastIp, ingressDevice,
-                                                                          mcastUtils.assignedVlan(source)));
-            if (nextObjVers == null) {
-                log.warn("Absent next objective for {}", new McastStoreKey(mcastIp, ingressDevice,
-                        mcastUtils.assignedVlan(source)));
-                return portBuilder.build();
-            }
-            NextObjective nextObj = nextObjVers.value();
-            Set<PortNumber> ports = mcastUtils.getPorts(nextObj.next());
-            // Let's find out all the ingress-transit ports
-            for (PortNumber port : ports) {
-                // Spine-facing port should have no subnet and no xconnect
-                if (srManager.deviceConfiguration() != null &&
-                        srManager.deviceConfiguration().getPortSubnets(ingressDevice, port).isEmpty() &&
-                        (srManager.xconnectService == null ||
-                        !srManager.xconnectService.hasXconnect(new ConnectPoint(ingressDevice, port)))) {
-                    portBuilder.add(port);
-                }
-            }
-        }
-        return portBuilder.build();
-    }
-
-    /**
      * Verify if a given connect point is sink for this group.
      *
      * @param mcastIp group address
@@ -1683,7 +1631,7 @@
     private boolean isSinkReachable(IpAddress mcastIp, ConnectPoint sink,
                                     ConnectPoint source) {
         return sink.deviceId().equals(source.deviceId()) ||
-                getPath(source.deviceId(), sink.deviceId(), mcastIp, null, source).isPresent();
+                getPath(source.deviceId(), sink.deviceId(), mcastIp, null).isPresent();
     }
 
     /**
@@ -1760,7 +1708,9 @@
             mcastFilteringObjStore.add(filterObjStoreKey);
         } else {
             // do nothing
-            log.debug("Filtering already present. Abort");
+            log.debug("Filtering already present for connect point {}, vlan {} and {}. Abort",
+                      filterObjStoreKey.ingressCP(), filterObjStoreKey.vlanId(),
+                      filterObjStoreKey.isIpv4() ? "IPv4" : "IPv6");
         }
     }
 
@@ -1954,6 +1904,7 @@
      * @return the mapping mcastIp-device to next id
      */
     public Map<McastStoreKey, Integer> getNextIds(IpAddress mcastIp) {
+        log.info("mcastNexts {}", mcastNextObjStore.size());
         if (mcastIp != null) {
             return mcastNextObjStore.entrySet().stream()
                     .filter(mcastEntry -> mcastIp.equals(mcastEntry.getKey().mcastIp()))
@@ -1977,6 +1928,39 @@
     }
 
     /**
+     * Build the mcast paths.
+     *
+     * @param storedPaths mcast tree
+     * @param mcastIp the group ip
+     * @param source the source
+     */
+    private Map<ConnectPoint, List<ConnectPoint>> buildMcastPaths(Collection<? extends List<Link>> storedPaths,
+                                                                  IpAddress mcastIp, ConnectPoint source) {
+        Map<ConnectPoint, List<ConnectPoint>> mcastTree = Maps.newHashMap();
+        // Local sinks
+        Set<ConnectPoint> localSinks = getSinks(mcastIp, source.deviceId(), source);
+        localSinks.forEach(localSink -> mcastTree.put(localSink, Lists.newArrayList(localSink, source)));
+        // Remote sinks
+        storedPaths.forEach(path -> {
+            List<Link> links = path;
+            DeviceId egressDevice = links.get(links.size() - 1).dst().deviceId();
+            Set<ConnectPoint> remoteSinks = getSinks(mcastIp, egressDevice, source);
+            List<ConnectPoint> connectPoints = Lists.newArrayList(source);
+            links.forEach(link -> {
+                connectPoints.add(link.src());
+                connectPoints.add(link.dst());
+            });
+            Collections.reverse(connectPoints);
+            remoteSinks.forEach(remoteSink -> {
+                List<ConnectPoint> finalPath = Lists.newArrayList(connectPoints);
+                finalPath.add(0, remoteSink);
+                mcastTree.put(remoteSink, finalPath);
+            });
+        });
+        return mcastTree;
+    }
+
+    /**
      * Returns the associated roles to the mcast groups.
      *
      * @param mcastIp the group ip
@@ -1985,6 +1969,7 @@
      */
     public Map<McastRoleStoreKey, McastRole> getMcastRoles(IpAddress mcastIp,
                                                        ConnectPoint sourcecp) {
+        log.info("mcastRoles {}", mcastRoleStore.size());
         if (mcastIp != null) {
             Map<McastRoleStoreKey, McastRole> roles = mcastRoleStore.entrySet().stream()
                     .filter(mcastEntry -> mcastIp.equals(mcastEntry.getKey().mcastIp()))
@@ -2012,6 +1997,8 @@
      */
     public Multimap<ConnectPoint, List<ConnectPoint>> getMcastTrees(IpAddress mcastIp,
                                                                     ConnectPoint sourcecp) {
+        // TODO remove
+        log.info("{}", getStoredPaths(mcastIp));
         Multimap<ConnectPoint, List<ConnectPoint>> mcastTrees = HashMultimap.create();
         Set<ConnectPoint> sources = mcastUtils.getSources(mcastIp);
         if (sourcecp != null) {
@@ -2020,12 +2007,13 @@
         }
         if (!sources.isEmpty()) {
             sources.forEach(source -> {
-                Map<ConnectPoint, List<ConnectPoint>> mcastPaths = Maps.newHashMap();
-                Set<DeviceId> visited = Sets.newHashSet();
-                List<ConnectPoint> currentPath = Lists.newArrayList(source);
-                mcastUtils.buildMcastPaths(mcastNextObjStore.asJavaMap(), source.deviceId(), visited, mcastPaths,
-                        currentPath, mcastIp, source);
-                mcastPaths.forEach(mcastTrees::put);
+                McastPathStoreKey pathStoreKey = new McastPathStoreKey(mcastIp, source);
+                Collection<? extends List<Link>> storedPaths = Versioned.valueOrElse(
+                        mcastPathStore.get(pathStoreKey), Lists.newArrayList());
+                // TODO remove
+                log.info("Paths for group {} and source {} - {}", mcastIp, source, storedPaths.size());
+                Map<ConnectPoint, List<ConnectPoint>> mcastTree = buildMcastPaths(storedPaths, mcastIp, source);
+                mcastTree.forEach(mcastTrees::put);
             });
         }
         return mcastTrees;
@@ -2047,6 +2035,8 @@
      * @return the mapping group-node
      */
     public Map<DeviceId, List<McastFilteringObjStoreKey>> getMcastFilters() {
+        // TODO remove
+        log.info("mcastFilters {}", mcastFilteringObjStore.size());
         Map<DeviceId, List<McastFilteringObjStoreKey>> mapping = Maps.newHashMap();
         Set<McastFilteringObjStoreKey> currentKeys = Sets.newHashSet(mcastFilteringObjStore);
         currentKeys.forEach(filteringObjStoreKey ->
diff --git a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKey.java b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKey.java
new file mode 100644
index 0000000..76f46fe
--- /dev/null
+++ b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKey.java
@@ -0,0 +1,100 @@
+/*
+ * Copyright 2020-present Open Networking Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.segmentrouting.mcast;
+
+import org.onlab.packet.IpAddress;
+import org.onosproject.net.ConnectPoint;
+
+import java.util.Objects;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkNotNull;
+
+/**
+ * Key of multicast path store.
+ */
+public class McastPathStoreKey {
+    // Identify path using group address and source
+    private final IpAddress mcastIp;
+    private final ConnectPoint source;
+
+    /**
+     * Constructs the key of multicast path store.
+     *
+     * @param mcastIp multicast group IP address
+     * @param source source connect point
+     */
+    public McastPathStoreKey(IpAddress mcastIp, ConnectPoint source) {
+        checkNotNull(mcastIp, "mcastIp cannot be null");
+        checkNotNull(source, "source cannot be null");
+        checkArgument(mcastIp.isMulticast(), "mcastIp must be a multicast address");
+        this.mcastIp = mcastIp;
+        this.source = source;
+    }
+
+    // Constructor for serialization
+    private McastPathStoreKey() {
+        this.mcastIp = null;
+        this.source = null;
+    }
+
+    /**
+     * Returns the multicast IP address of this key.
+     *
+     * @return multicast IP
+     */
+    public IpAddress mcastIp() {
+        return mcastIp;
+    }
+
+    /**
+     * Returns the device ID of this key.
+     *
+     * @return device ID
+     */
+    public ConnectPoint source() {
+        return source;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (!(o instanceof McastPathStoreKey)) {
+            return false;
+        }
+        McastPathStoreKey that =
+                (McastPathStoreKey) o;
+        return (Objects.equals(this.mcastIp, that.mcastIp) &&
+                Objects.equals(this.source, that.source));
+    }
+
+    @Override
+    public int hashCode() {
+         return Objects.hash(mcastIp, source);
+    }
+
+    @Override
+    public String toString() {
+        return toStringHelper(getClass())
+                .add("mcastIp", mcastIp)
+                .add("source", source)
+                .toString();
+    }
+}
diff --git a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKeySerializer.java b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKeySerializer.java
new file mode 100644
index 0000000..4ef0194
--- /dev/null
+++ b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastPathStoreKeySerializer.java
@@ -0,0 +1,52 @@
+/*
+ * Copyright 2020-present Open Networking Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.segmentrouting.mcast;
+
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.onlab.packet.IpAddress;
+import org.onosproject.net.ConnectPoint;
+
+/**
+ * Custom serializer for {@link McastPathStoreKey}.
+ */
+class McastPathStoreKeySerializer extends Serializer<McastPathStoreKey> {
+
+    /**
+     * Creates {@link McastPathStoreKeySerializer} serializer instance.
+     */
+    McastPathStoreKeySerializer() {
+        // non-null, immutable
+        super(false, true);
+    }
+
+    @Override
+    public void write(Kryo kryo, Output output, McastPathStoreKey object) {
+        kryo.writeClassAndObject(output, object.mcastIp());
+        kryo.writeClassAndObject(output, object.source());
+    }
+
+    @Override
+    public McastPathStoreKey read(Kryo kryo, Input input, Class<McastPathStoreKey> type) {
+        IpAddress mcastIp = (IpAddress) kryo.readClassAndObject(input);
+        ConnectPoint source = (ConnectPoint) kryo.readClassAndObject(input);
+        return new McastPathStoreKey(mcastIp, source);
+    }
+}
diff --git a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
index c4e1ad4..396b998 100644
--- a/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
+++ b/apps/segmentrouting/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
@@ -16,6 +16,7 @@
 
 package org.onosproject.segmentrouting.mcast;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
@@ -34,6 +35,7 @@
 import org.onosproject.net.DeviceId;
 import org.onosproject.net.HostId;
 import org.onosproject.net.Link;
+import org.onosproject.net.Path;
 import org.onosproject.net.PortNumber;
 import org.onosproject.net.config.basics.McastConfig;
 import org.onosproject.net.flow.DefaultTrafficSelector;
@@ -50,6 +52,10 @@
 import org.onosproject.net.flowobjective.ForwardingObjective;
 import org.onosproject.net.flowobjective.NextObjective;
 import org.onosproject.net.flowobjective.ObjectiveContext;
+import org.onosproject.net.topology.LinkWeigher;
+import org.onosproject.net.topology.Topology;
+import org.onosproject.net.topology.TopologyService;
+import org.onosproject.segmentrouting.SRLinkWeigher;
 import org.onosproject.segmentrouting.SegmentRoutingManager;
 import org.onosproject.segmentrouting.SegmentRoutingService;
 import org.onosproject.segmentrouting.config.DeviceConfigNotFoundException;
@@ -57,8 +63,10 @@
 import org.slf4j.Logger;
 
 import java.util.Collection;
+import java.util.Collections;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -66,11 +74,11 @@
  * Utility class for Multicast Handler.
  */
 class McastUtils {
-
     // Internal reference to the log
     private final Logger log;
-    // Internal reference to SR Manager
-    private SegmentRoutingManager srManager;
+    // Internal reference to SR Manager and topology service
+    private final SegmentRoutingManager srManager;
+    private final TopologyService topologyService;
     // Internal reference to the app id
     private ApplicationId coreAppId;
     // Hashing function for the multicast hasher
@@ -87,6 +95,7 @@
      */
     McastUtils(SegmentRoutingManager srManager, ApplicationId coreAppId, Logger log) {
         this.srManager = srManager;
+        this.topologyService = srManager.topologyService;
         this.coreAppId = coreAppId;
         this.log = log;
         this.mcastLeaderCache = Maps.newConcurrentMap();
@@ -95,7 +104,7 @@
     /**
      * Clean up when deactivating the application.
      */
-    public void terminate() {
+    void terminate() {
         mcastLeaderCache.clear();
     }
 
@@ -233,6 +242,8 @@
         return srManager.getDefaultInternalVlan();
     }
 
+
+
     /**
      * Gets sources connect points of given multicast group.
      *
@@ -477,64 +488,228 @@
     }
 
     /**
-     * Build recursively the mcast paths.
+     * Go through all the paths, looking for shared links to be used
+     * in the final path computation.
      *
-     * @param mcastNextObjStore mcast next obj store
-     * @param toVisit the node to visit
-     * @param visited the visited nodes
-     * @param mcastPaths the current mcast paths
-     * @param currentPath the current path
-     * @param mcastIp the group ip
-     * @param source the source
+     * @param egresses egress devices
+     * @param availablePaths all the available paths towards the egress
+     * @return shared links between egress devices
      */
-    void buildMcastPaths(Map<McastStoreKey, NextObjective> mcastNextObjStore,
-                                 DeviceId toVisit, Set<DeviceId> visited,
-                                 Map<ConnectPoint, List<ConnectPoint>> mcastPaths,
-                                 List<ConnectPoint> currentPath, IpAddress mcastIp,
-                                 ConnectPoint source) {
-        log.debug("Building Multicast paths recursively for {} - next device to visit is {}",
-                  mcastIp, toVisit);
-        // If we have visited the node to visit there is a loop
-        if (visited.contains(toVisit)) {
-            return;
-        }
-        // Visit next-hop
-        visited.add(toVisit);
-        VlanId assignedVlan = assignedVlan(toVisit.equals(source.deviceId()) ? source : null);
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, toVisit, assignedVlan);
-        // Looking for next-hops
-        if (mcastNextObjStore.containsKey(mcastStoreKey)) {
-            // Build egress connect points, get ports and build relative cps
-            NextObjective nextObjective = mcastNextObjStore.get(mcastStoreKey);
-            Set<PortNumber> outputPorts = getPorts(nextObjective.next());
-            ImmutableSet.Builder<ConnectPoint> cpBuilder = ImmutableSet.builder();
-            outputPorts.forEach(portNumber -> cpBuilder.add(new ConnectPoint(toVisit, portNumber)));
-            Set<ConnectPoint> egressPoints = cpBuilder.build();
-            Set<Link> egressLinks;
-            List<ConnectPoint> newCurrentPath;
-            Set<DeviceId> newVisited;
-            DeviceId newToVisit;
-            for (ConnectPoint egressPoint : egressPoints) {
-                egressLinks = srManager.linkService.getEgressLinks(egressPoint);
-                // If it does not have egress links, stop
-                if (egressLinks.isEmpty()) {
-                    // Add the connect points to the path
-                    newCurrentPath = Lists.newArrayList(currentPath);
-                    newCurrentPath.add(0, egressPoint);
-                    mcastPaths.put(egressPoint, newCurrentPath);
-                } else {
-                    newVisited = Sets.newHashSet(visited);
-                    // Iterate over the egress links for the next hops
-                    for (Link egressLink : egressLinks) {
-                        newToVisit = egressLink.dst().deviceId();
-                        newCurrentPath = Lists.newArrayList(currentPath);
-                        newCurrentPath.add(0, egressPoint);
-                        newCurrentPath.add(0, egressLink.dst());
-                        buildMcastPaths(mcastNextObjStore, newToVisit, newVisited, mcastPaths, newCurrentPath, mcastIp,
-                                source);
-                    }
-                }
+    private Set<Link> exploreMcastTree(Set<DeviceId> egresses,
+                                       Map<DeviceId, List<Path>> availablePaths) {
+        int minLength = Integer.MAX_VALUE;
+        int length;
+        List<Path> currentPaths;
+        // Verify the source can still reach all the egresses
+        for (DeviceId egress : egresses) {
+            // From the source we cannot reach all the sinks
+            // just continue and let's figure out after
+            currentPaths = availablePaths.get(egress);
+            if (currentPaths.isEmpty()) {
+                continue;
+            }
+            // Get the length of the first one available, update the min length
+            length = currentPaths.get(0).links().size();
+            if (length < minLength) {
+                minLength = length;
             }
         }
+        // If there are no paths
+        if (minLength == Integer.MAX_VALUE) {
+            return Collections.emptySet();
+        }
+        int index = 0;
+        Set<Link> sharedLinks = Sets.newHashSet();
+        Set<Link> currentSharedLinks;
+        Set<Link> currentLinks;
+        DeviceId egressToRemove = null;
+        // Let's find out the shared links
+        while (index < minLength) {
+            // Initialize the intersection with the paths related to the first egress
+            currentPaths = availablePaths.get(egresses.stream().findFirst().orElse(null));
+            currentSharedLinks = Sets.newHashSet();
+            // Iterate over the paths and take the "index" links
+            for (Path path : currentPaths) {
+                currentSharedLinks.add(path.links().get(index));
+            }
+            // Iterate over the remaining egress
+            for (DeviceId egress : egresses) {
+                // Iterate over the paths and take the "index" links
+                currentLinks = Sets.newHashSet();
+                for (Path path : availablePaths.get(egress)) {
+                    currentLinks.add(path.links().get(index));
+                }
+                // Do intersection
+                currentSharedLinks = Sets.intersection(currentSharedLinks, currentLinks);
+                // If there are no shared paths exit and record the device to remove
+                // we have to retry with a subset of sinks
+                if (currentSharedLinks.isEmpty()) {
+                    egressToRemove = egress;
+                    index = minLength;
+                    break;
+                }
+            }
+            sharedLinks.addAll(currentSharedLinks);
+            index++;
+        }
+        // If the shared links is empty and there are egress let's retry another time with less sinks,
+        // we can still build optimal subtrees
+        if (sharedLinks.isEmpty() && egresses.size() > 1 && egressToRemove != null) {
+            egresses.remove(egressToRemove);
+            sharedLinks = exploreMcastTree(egresses, availablePaths);
+        }
+        return sharedLinks;
     }
+
+    /**
+     * Build Mcast tree having as root the given source and as leaves the given egress points.
+     *
+     * @param mcastIp multicast group
+     * @param source source of the tree
+     * @param sinks leaves of the tree
+     * @return the computed Mcast tree
+     */
+    Map<ConnectPoint, List<Path>> computeSinkMcastTree(IpAddress mcastIp,
+                                                       DeviceId source,
+                                                       Set<ConnectPoint> sinks) {
+        // Get the egress devices, remove source from the egress if present
+        Set<DeviceId> egresses = sinks.stream().map(ConnectPoint::deviceId)
+                .filter(deviceId -> !deviceId.equals(source)).collect(Collectors.toSet());
+        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(mcastIp, source, egresses);
+        final Map<ConnectPoint, List<Path>> finalTree = Maps.newHashMap();
+        // We need to put back the source if it was originally present
+        sinks.forEach(sink -> {
+            List<Path> sinkPaths = mcastTree.get(sink.deviceId());
+            finalTree.put(sink, sinkPaths != null ? sinkPaths : ImmutableList.of());
+        });
+        return finalTree;
+    }
+
+    /**
+     * Build Mcast tree having as root the given source and as leaves the given egress.
+     *
+     * @param mcastIp multicast group
+     * @param source source of the tree
+     * @param egresses leaves of the tree
+     * @return the computed Mcast tree
+     */
+    private Map<DeviceId, List<Path>> computeMcastTree(IpAddress mcastIp,
+                                                       DeviceId source,
+                                                       Set<DeviceId> egresses) {
+        log.debug("Computing tree for Multicast group {}, source {} and leafs {}",
+                  mcastIp, source, egresses);
+        // Pre-compute all the paths
+        Map<DeviceId, List<Path>> availablePaths = Maps.newHashMap();
+        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
+                                                                       Collections.emptySet())));
+        // Explore the topology looking for shared links amongst the egresses
+        Set<Link> linksToEnforce = exploreMcastTree(Sets.newHashSet(egresses), availablePaths);
+        // Build the final paths enforcing the shared links between egress devices
+        availablePaths.clear();
+        egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
+                                                                       linksToEnforce)));
+        return availablePaths;
+    }
+
+    /**
+     * Gets path from src to dst computed using the custom link weigher.
+     *
+     * @param src source device ID
+     * @param dst destination device ID
+     * @param linksToEnforce links to be enforced
+     * @return list of paths from src to dst
+     */
+    List<Path> getPaths(DeviceId src, DeviceId dst, Set<Link> linksToEnforce) {
+        final Topology currentTopology = topologyService.currentTopology();
+        final LinkWeigher linkWeigher = new SRLinkWeigher(srManager, src, linksToEnforce);
+        List<Path> allPaths = Lists.newArrayList(topologyService.getPaths(currentTopology, src, dst, linkWeigher));
+        log.trace("{} path(s) found from {} to {}", allPaths.size(), src, dst);
+        return allPaths;
+    }
+
+    /**
+     * Gets a stored path having dst as egress.
+     *
+     * @param dst destination device ID
+     * @param storedPaths paths list
+     * @return an optional path
+     */
+    Optional<? extends List<Link>> getStoredPath(DeviceId dst, Collection<? extends List<Link>> storedPaths) {
+        return storedPaths.stream()
+                .filter(path -> path.get(path.size() - 1).dst().deviceId().equals(dst))
+                .findFirst();
+    }
+
+    /**
+     * Returns a set of affected paths by the failed element.
+     *
+     * @param paths the paths to check
+     * @param failedElement the failed element
+     * @return the affected paths
+     */
+    Set<List<Link>> getAffectedPaths(Set<List<Link>> paths, Object failedElement) {
+        if (failedElement instanceof DeviceId) {
+            return getAffectedPathsByDevice(paths, failedElement);
+        }
+        return getAffectedPathsByLink(paths, failedElement);
+    }
+
+    private Set<List<Link>> getAffectedPathsByDevice(Set<List<Link>> paths, Object failedElement) {
+        DeviceId affectedDevice = (DeviceId) failedElement;
+        Set<List<Link>> affectedPaths = Sets.newHashSet();
+        paths.forEach(path -> {
+            if (path.stream().anyMatch(link -> link.src().deviceId().equals(affectedDevice))) {
+                affectedPaths.add(path);
+            }
+        });
+        return affectedPaths;
+    }
+
+    private Set<List<Link>> getAffectedPathsByLink(Set<List<Link>> paths, Object failedElement) {
+        Link affectedLink = (Link) failedElement;
+        Set<List<Link>> affectedPaths = Sets.newHashSet();
+        paths.forEach(path -> {
+            if (path.contains(affectedLink)) {
+                affectedPaths.add(path);
+            }
+        });
+        return affectedPaths;
+    }
+
+    /**
+     * Checks if the failure is affecting the transit device.
+     *
+     * @param devices the transit devices
+     * @param failedElement the failed element
+     * @return true if the failed element is affecting the transit devices
+     */
+    boolean isInfraFailure(Set<DeviceId> devices, Object failedElement) {
+        if (failedElement instanceof DeviceId) {
+            return isInfraFailureByDevice(devices, failedElement);
+        }
+        return true;
+    }
+
+    private boolean isInfraFailureByDevice(Set<DeviceId> devices, Object failedElement) {
+        DeviceId affectedDevice = (DeviceId) failedElement;
+        return devices.contains(affectedDevice);
+    }
+
+    /**
+     * Checks if a port is an infra port.
+     *
+     * @param connectPoint port to be checked
+     * @param storedPaths paths to be checked against
+     * @return true if the port is an infra port. False otherwise.
+     */
+    boolean isInfraPort(ConnectPoint connectPoint, Collection<? extends List<Link>> storedPaths) {
+        for (List<Link> path : storedPaths) {
+            if (path.stream().anyMatch(link -> link.src().equals(connectPoint) ||
+                    link.dst().equals(connectPoint))) {
+                return true;
+            }
+        }
+        return false;
+    }
+
 }