[CORD-2839] Handling multiple sources

Change-Id: I77bd98e8a12e5044421ef5e0b048833dd688cb2e
diff --git a/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingManager.java b/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingManager.java
index b11de60..156240a 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingManager.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingManager.java
@@ -108,6 +108,7 @@
 import org.onosproject.segmentrouting.grouphandler.NextNeighbors;
 import org.onosproject.segmentrouting.mcast.McastHandler;
 import org.onosproject.segmentrouting.mcast.McastRole;
+import org.onosproject.segmentrouting.mcast.McastRoleStoreKey;
 import org.onosproject.segmentrouting.pwaas.DefaultL2Tunnel;
 import org.onosproject.segmentrouting.pwaas.DefaultL2TunnelDescription;
 import org.onosproject.segmentrouting.pwaas.DefaultL2TunnelHandler;
@@ -120,7 +121,7 @@
 
 import org.onosproject.segmentrouting.storekey.DestinationSetNextObjectiveStoreKey;
 import org.onosproject.segmentrouting.storekey.DummyVlanIdStoreKey;
-import org.onosproject.segmentrouting.storekey.McastStoreKey;
+import org.onosproject.segmentrouting.mcast.McastStoreKey;
 import org.onosproject.segmentrouting.storekey.PortNextObjectiveStoreKey;
 import org.onosproject.segmentrouting.storekey.VlanNextObjectiveStoreKey;
 import org.onosproject.segmentrouting.storekey.XConnectStoreKey;
@@ -751,6 +752,11 @@
     }
 
     @Override
+    public Map<McastRoleStoreKey, McastRole> getMcastRoles(IpAddress mcastIp, ConnectPoint sourcecp) {
+        return mcastHandler.getMcastRoles(mcastIp, sourcecp);
+    }
+
+    @Override
     public Map<ConnectPoint, List<ConnectPoint>> getMcastPaths(IpAddress mcastIp) {
         return mcastHandler.getMcastPaths(mcastIp);
     }
diff --git a/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingService.java b/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingService.java
index f52fad1..03ef5b5 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingService.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/SegmentRoutingService.java
@@ -25,6 +25,7 @@
 import org.onosproject.net.PortNumber;
 import org.onosproject.segmentrouting.grouphandler.NextNeighbors;
 import org.onosproject.segmentrouting.mcast.McastRole;
+import org.onosproject.segmentrouting.mcast.McastRoleStoreKey;
 import org.onosproject.segmentrouting.pwaas.DefaultL2TunnelDescription;
 import org.onosproject.segmentrouting.pwaas.L2Tunnel;
 import org.onosproject.segmentrouting.pwaas.L2TunnelHandler;
@@ -33,7 +34,7 @@
 import org.onosproject.segmentrouting.storekey.DestinationSetNextObjectiveStoreKey;
 
 import com.google.common.collect.ImmutableMap;
-import org.onosproject.segmentrouting.storekey.McastStoreKey;
+import org.onosproject.segmentrouting.mcast.McastStoreKey;
 
 import java.util.List;
 import java.util.Map;
@@ -223,7 +224,7 @@
      */
     ImmutableMap<DeviceId, Set<PortNumber>> getDownedPortState();
 
-   /**
+    /**
      * Returns the associated next ids to the mcast groups or to the single
      * group if mcastIp is present.
      *
@@ -245,6 +246,16 @@
     Map<McastStoreKey, McastRole> getMcastRoles(IpAddress mcastIp);
 
     /**
+     * Returns the associated roles to the mcast groups.
+     *
+     * @param mcastIp the group ip
+     * @param sourcecp the source connect point
+     * @return the mapping mcastIp-device-source to mcast role
+     */
+    Map<McastRoleStoreKey, McastRole> getMcastRoles(IpAddress mcastIp,
+                                                    ConnectPoint sourcecp);
+
+    /**
      * Returns the associated paths to the mcast group.
      *
      * @param mcastIp the group ip
@@ -255,7 +266,6 @@
     @Deprecated
     Map<ConnectPoint, List<ConnectPoint>> getMcastPaths(IpAddress mcastIp);
 
-
     /**
      * Returns the associated trees to the mcast group.
      *
diff --git a/app/src/main/java/org/onosproject/segmentrouting/cli/McastNextListCommand.java b/app/src/main/java/org/onosproject/segmentrouting/cli/McastNextListCommand.java
index e5d8b12..cefb244 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/cli/McastNextListCommand.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/cli/McastNextListCommand.java
@@ -20,11 +20,13 @@
 import org.apache.karaf.shell.commands.Command;
 import org.apache.karaf.shell.commands.Option;
 import org.onlab.packet.IpAddress;
+import org.onlab.packet.VlanId;
 import org.onosproject.cli.AbstractShellCommand;
 import org.onosproject.mcast.cli.McastGroupCompleter;
 import org.onosproject.net.DeviceId;
 import org.onosproject.segmentrouting.SegmentRoutingService;
-import org.onosproject.segmentrouting.storekey.McastStoreKey;
+import org.onosproject.segmentrouting.mcast.McastStoreKey;
+import org.apache.commons.lang3.tuple.Pair;
 
 import java.util.Map;
 import java.util.Set;
@@ -69,19 +71,20 @@
         // Print the nextids for each group
         mcastGroups.forEach(group -> {
             // Create a new map for the group
-            Map<DeviceId, Integer> deviceIdNextMap = Maps.newHashMap();
+            Map<Pair<DeviceId, VlanId>, Integer> deviceIdNextMap = Maps.newHashMap();
             keyToNextId.entrySet()
                     .stream()
                     // Filter only the elements related to this group
                     .filter(entry -> entry.getKey().mcastIp().equals(group))
                     // For each create a new entry in the group related map
-                    .forEach(entry -> deviceIdNextMap.put(entry.getKey().deviceId(), entry.getValue()));
+                    .forEach(entry -> deviceIdNextMap.put(Pair.of(entry.getKey().deviceId(),
+                                                          entry.getKey().vlanId()), entry.getValue()));
             // Print the map
             printMcastNext(group, deviceIdNextMap);
         });
     }
 
-    private void printMcastNext(IpAddress mcastGroup, Map<DeviceId, Integer> deviceIdNextMap) {
+    private void printMcastNext(IpAddress mcastGroup, Map<Pair<DeviceId, VlanId>, Integer> deviceIdNextMap) {
         print(FORMAT_MAPPING, mcastGroup, deviceIdNextMap);
     }
 }
diff --git a/app/src/main/java/org/onosproject/segmentrouting/cli/McastRoleListCommand.java b/app/src/main/java/org/onosproject/segmentrouting/cli/McastRoleListCommand.java
new file mode 100644
index 0000000..facd6a1
--- /dev/null
+++ b/app/src/main/java/org/onosproject/segmentrouting/cli/McastRoleListCommand.java
@@ -0,0 +1,111 @@
+/*
+ * Copyright 2018-present Open Networking Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.segmentrouting.cli;
+
+import com.google.common.collect.ArrayListMultimap;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multimap;
+import org.apache.karaf.shell.commands.Command;
+import org.apache.karaf.shell.commands.Option;
+import org.onlab.packet.IpAddress;
+import org.onosproject.cli.AbstractShellCommand;
+import org.onosproject.mcast.cli.McastGroupCompleter;
+import org.onosproject.net.ConnectPoint;
+import org.onosproject.net.DeviceId;
+import org.onosproject.segmentrouting.SegmentRoutingService;
+import org.onosproject.segmentrouting.mcast.McastRole;
+import org.onosproject.segmentrouting.mcast.McastRoleStoreKey;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+import static com.google.common.base.Strings.isNullOrEmpty;
+
+/**
+ * Command to show the list of mcast roles.
+ */
+@Command(scope = "onos", name = "sr-mcast-role",
+        description = "Lists all mcast roles")
+public class McastRoleListCommand extends AbstractShellCommand {
+
+    // OSGi workaround to introduce package dependency
+    McastGroupCompleter completer;
+
+    // Format for group line
+    private static final String FORMAT_MAPPING = "%s,%s  ingress=%s\ttransit=%s\tegress=%s";
+
+    @Option(name = "-gAddr", aliases = "--groupAddress",
+            description = "IP Address of the multicast group",
+            valueToShowInHelp = "224.0.0.0",
+            required = false, multiValued = false)
+    String gAddr = null;
+
+    @Option(name = "-src", aliases = "--connectPoint",
+            description = "Source port of:XXXXXXXXXX/XX",
+            valueToShowInHelp = "of:0000000000000001/1",
+            required = false, multiValued = false)
+    String source = null;
+
+    @Override
+    protected void execute() {
+        // Verify mcast group
+        IpAddress mcastGroup = null;
+        // We want to use source cp only for a specific group
+        ConnectPoint sourcecp = null;
+        if (!isNullOrEmpty(gAddr)) {
+            mcastGroup = IpAddress.valueOf(gAddr);
+            if (!isNullOrEmpty(source)) {
+                sourcecp = ConnectPoint.deviceConnectPoint(source);
+            }
+        }
+        // Get SR service, the roles and the groups
+        SegmentRoutingService srService = get(SegmentRoutingService.class);
+        Map<McastRoleStoreKey, McastRole> keyToRole = srService.getMcastRoles(mcastGroup, sourcecp);
+        Set<IpAddress> mcastGroups = keyToRole.keySet().stream()
+                .map(McastRoleStoreKey::mcastIp)
+                .collect(Collectors.toSet());
+        // Print the roles for each group
+        mcastGroups.forEach(group -> {
+            // Create a new map for the group
+            Map<ConnectPoint, Multimap<McastRole, DeviceId>> roleDeviceIdMap = Maps.newHashMap();
+            keyToRole.entrySet()
+                    .stream()
+                    .filter(entry -> entry.getKey().mcastIp().equals(group))
+                    .forEach(entry -> roleDeviceIdMap.compute(entry.getKey().source(), (gsource, map) -> {
+                        map = map == null ? ArrayListMultimap.create() : map;
+                        map.put(entry.getValue(), entry.getKey().deviceId());
+                        return map;
+                    }));
+            roleDeviceIdMap.forEach((gsource, map) -> {
+                // Print the map
+                printMcastRole(group, gsource,
+                               map.get(McastRole.INGRESS),
+                               map.get(McastRole.TRANSIT),
+                               map.get(McastRole.EGRESS));
+            });
+        });
+    }
+
+    private void printMcastRole(IpAddress mcastGroup, ConnectPoint source,
+                                Collection<DeviceId> ingress,
+                                Collection<DeviceId> transit,
+                                Collection<DeviceId> egress) {
+        print(FORMAT_MAPPING, mcastGroup, source, ingress, transit, egress);
+    }
+}
diff --git a/app/src/main/java/org/onosproject/segmentrouting/cli/McastTreeListCommand.java b/app/src/main/java/org/onosproject/segmentrouting/cli/McastTreeListCommand.java
index a1ee1f2..89f107a 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/cli/McastTreeListCommand.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/cli/McastTreeListCommand.java
@@ -88,13 +88,15 @@
             }
             Multimap<ConnectPoint, List<ConnectPoint>> mcastTree = srService.getMcastTrees(group,
                                                                                            sourcecp);
-            // Build a json object for each group
-            if (outputJson()) {
-                root.putPOJO(group.toString(), json(mcastTree));
-            } else {
-                // Banner and then the trees
-                printMcastGroup(group);
-                mcastTree.forEach(this::printMcastSink);
+            if (!mcastTree.isEmpty()) {
+                // Build a json object for each group
+                if (outputJson()) {
+                    root.putPOJO(group.toString(), json(mcastTree));
+                } else {
+                    // Banner and then the trees
+                    printMcastGroup(group);
+                    mcastTree.forEach(this::printMcastSink);
+                }
             }
         });
 
diff --git a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
index 289778e..584e51d 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
@@ -16,6 +16,7 @@
 
 package org.onosproject.segmentrouting.mcast;
 
+import com.google.common.base.Objects;
 import com.google.common.cache.Cache;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.cache.RemovalCause;
@@ -52,11 +53,9 @@
 import org.onosproject.net.topology.TopologyService;
 import org.onosproject.segmentrouting.SRLinkWeigher;
 import org.onosproject.segmentrouting.SegmentRoutingManager;
-import org.onosproject.segmentrouting.storekey.McastStoreKey;
 import org.onosproject.store.serializers.KryoNamespaces;
 import org.onosproject.store.service.ConsistentMap;
 import org.onosproject.store.service.Serializer;
-import org.onosproject.store.service.Versioned;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -93,35 +92,28 @@
  * Handles Multicast related events.
  */
 public class McastHandler {
-    // Logger instance
     private static final Logger log = LoggerFactory.getLogger(McastHandler.class);
-    // Reference to srManager and most used internal objects
     private final SegmentRoutingManager srManager;
     private final TopologyService topologyService;
     private final McastUtils mcastUtils;
-    // Internal store of the Mcast nextobjectives
     private final ConsistentMap<McastStoreKey, NextObjective> mcastNextObjStore;
-    // Internal store of the Mcast roles
-    private final ConsistentMap<McastStoreKey, McastRole> mcastRoleStore;
+    private final ConsistentMap<McastRoleStoreKey, McastRole> mcastRoleStore;
 
     // Wait time for the cache
     private static final int WAIT_TIME_MS = 1000;
 
-    /**
-     * The mcastEventCache is implemented to avoid race condition by giving more time to the
-     * underlying subsystems to process previous calls.
-     */
+    // The mcastEventCache is implemented to avoid race conditions by giving more time
+    // to the underlying subsystems to process previous calls.
     private Cache<McastCacheKey, McastEvent> mcastEventCache = CacheBuilder.newBuilder()
             .expireAfterWrite(WAIT_TIME_MS, TimeUnit.MILLISECONDS)
             .removalListener((RemovalNotification<McastCacheKey, McastEvent> notification) -> {
-                // Get group ip, sink and related event
                 IpAddress mcastIp = notification.getKey().mcastIp();
                 HostId sink = notification.getKey().sinkHost();
                 McastEvent mcastEvent = notification.getValue();
                 RemovalCause cause = notification.getCause();
                 log.debug("mcastEventCache removal event. group={}, sink={}, mcastEvent={}, cause={}",
                           mcastIp, sink, mcastEvent, cause);
-                // If it expires or it has been replaced, we deque the event
+                // If it expires or it has been replaced, we dequeue the event - not when evicted
                 switch (notification.getCause()) {
                     case REPLACED:
                     case EXPIRED:
@@ -133,18 +125,15 @@
             }).build();
 
     private void enqueueMcastEvent(McastEvent mcastEvent) {
-        // Retrieve, currentData, prevData and the group
         final McastRouteUpdate mcastRouteUpdate = mcastEvent.subject();
         final McastRouteUpdate mcastRoutePrevUpdate = mcastEvent.prevSubject();
         final IpAddress group = mcastRoutePrevUpdate.route().group();
-        // Let's create the keys of the cache
         ImmutableSet.Builder<HostId> sinksBuilder = ImmutableSet.builder();
         if (mcastEvent.type() == SOURCES_ADDED ||
                 mcastEvent.type() == SOURCES_REMOVED) {
-            // FIXME To be addressed with multiple sources support
-            sinksBuilder.addAll(Collections.emptySet());
+            // Current and previous subjects differ only in the source connect points
+            sinksBuilder.addAll(mcastRouteUpdate.sinks().keySet());
         } else if (mcastEvent.type() == SINKS_ADDED) {
-            // We need to process the host id one by one
             mcastRouteUpdate.sinks().forEach(((hostId, connectPoints) -> {
                 // Get the previous locations and verify if there are changes
                 Set<ConnectPoint> prevConnectPoints = mcastRoutePrevUpdate.sinks().get(hostId);
@@ -155,7 +144,6 @@
                 }
             }));
         } else if (mcastEvent.type() == SINKS_REMOVED) {
-            // We need to process the host id one by one
             mcastRoutePrevUpdate.sinks().forEach(((hostId, connectPoints) -> {
                 // Get the current locations and verify if there are changes
                 Set<ConnectPoint> currentConnectPoints = mcastRouteUpdate.sinks().get(hostId);
@@ -169,7 +157,6 @@
             // Current subject is null, just take the previous host ids
             sinksBuilder.addAll(mcastRoutePrevUpdate.sinks().keySet());
         }
-        // Push the elements in the cache
         sinksBuilder.build().forEach(sink -> {
             McastCacheKey cacheKey = new McastCacheKey(group, sink);
             mcastEventCache.put(cacheKey, mcastEvent);
@@ -177,55 +164,35 @@
     }
 
     private void dequeueMcastEvent(McastEvent mcastEvent) {
-        // Get new and old data
         final McastRouteUpdate mcastUpdate = mcastEvent.subject();
         final McastRouteUpdate mcastPrevUpdate = mcastEvent.prevSubject();
-        // Get source, mcast group
-        // FIXME To be addressed with multiple sources support
-        final ConnectPoint source = mcastPrevUpdate.sources()
-                .values()
-                .stream()
-                .flatMap(Collection::stream)
-                .findFirst()
-                .orElse(null);
         IpAddress mcastIp = mcastPrevUpdate.route().group();
-        // Get all the previous sinks
         Set<ConnectPoint> prevSinks = mcastPrevUpdate.sinks()
-                .values()
-                .stream()
-                .flatMap(Collection::stream)
-                .collect(Collectors.toSet());
-        // According to the event type let's call the proper method
+                .values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
+        Set<ConnectPoint> prevSources = mcastPrevUpdate.sources()
+                .values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
+        Set<ConnectPoint> sources;
         switch (mcastEvent.type()) {
             case SOURCES_ADDED:
-                // FIXME To be addressed with multiple sources support
-                // Get all the sinks
-                //Set<ConnectPoint> sinks = mcastRouteInfo.sinks();
-                // Compute the Mcast tree
-                //Map<ConnectPoint, List<Path>> mcasTree = computeSinkMcastTree(source.deviceId(), sinks);
-                // Process the given sinks using the pre-computed paths
-                //mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink, mcastIp, paths));
+                sources = mcastUpdate.sources()
+                        .values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
+                Set<ConnectPoint> sourcesToBeAdded = Sets.difference(sources, prevSources);
+                processSourcesAddedInternal(sourcesToBeAdded, mcastIp, mcastUpdate.sinks());
                 break;
             case SOURCES_REMOVED:
-                // FIXME To be addressed with multiple sources support
-                // Get old source
-                //ConnectPoint oldSource = mcastEvent.prevSubject().source().orElse(null);
-                // Just the first cached element will be processed
-                //processSourceUpdatedInternal(mcastIp, source, oldSource);
+                sources = mcastUpdate.sources()
+                        .values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
+                Set<ConnectPoint> sourcesToBeRemoved = Sets.difference(prevSources, sources);
+                processSourcesRemovedInternal(sourcesToBeRemoved, sources, mcastIp, mcastUpdate.sinks());
                 break;
             case ROUTE_REMOVED:
-                // Process the route removed, just the first cached element will be processed
-                processRouteRemovedInternal(source, mcastIp);
+                processRouteRemovedInternal(prevSources, mcastIp);
                 break;
             case SINKS_ADDED:
-                // FIXME To be addressed with multiple sources support
-                processSinksAddedInternal(source, mcastIp,
-                                          mcastUpdate.sinks(), prevSinks);
+                processSinksAddedInternal(prevSources, mcastIp, mcastUpdate.sinks(), prevSinks);
                 break;
             case SINKS_REMOVED:
-                // FIXME To be addressed with multiple sources support
-                processSinksRemovedInternal(source, mcastIp,
-                                            mcastUpdate.sinks(), mcastPrevUpdate.sinks());
+                processSinksRemovedInternal(prevSources, mcastIp, mcastUpdate.sinks(), mcastPrevUpdate.sinks());
                 break;
             default:
                 break;
@@ -234,21 +201,12 @@
 
     // Mcast lock to serialize local operations
     private final Lock mcastLock = new ReentrantLock();
-
-    /**
-     * Acquires the lock used when making mcast changes.
-     */
     private void mcastLock() {
         mcastLock.lock();
     }
-
-    /**
-     * Releases the lock used when making mcast changes.
-     */
     private void mcastUnlock() {
         mcastLock.unlock();
     }
-
     // Stability threshold for Mcast. Seconds
     private static final long MCAST_STABLITY_THRESHOLD = 5;
     // Last change done
@@ -268,10 +226,9 @@
         return (now - last) > MCAST_STABLITY_THRESHOLD;
     }
 
-    // Verify interval for Mcast
+    // Verify interval for Mcast bucket corrector
     private static final long MCAST_VERIFY_INTERVAL = 30;
-
-    // Executor for mcast bucket corrector
+    // Executor for mcast bucket corrector and for cache
     private ScheduledExecutorService executorService
             = newScheduledThreadPool(1, groupedThreads("mcastWorker", "mcastWorker-%d", log));
 
@@ -286,63 +243,67 @@
         this.topologyService = srManager.topologyService;
         KryoNamespace.Builder mcastKryo = new KryoNamespace.Builder()
                 .register(KryoNamespaces.API)
-                .register(McastStoreKey.class)
-                .register(McastRole.class);
+                .register(new McastStoreKeySerializer(), McastStoreKey.class);
         mcastNextObjStore = srManager.storageService
                 .<McastStoreKey, NextObjective>consistentMapBuilder()
                 .withName("onos-mcast-nextobj-store")
                 .withSerializer(Serializer.using(mcastKryo.build("McastHandler-NextObj")))
                 .build();
+        mcastKryo = new KryoNamespace.Builder()
+                .register(KryoNamespaces.API)
+                .register(new McastRoleStoreKeySerializer(), McastRoleStoreKey.class)
+                .register(McastRole.class);
         mcastRoleStore = srManager.storageService
-                .<McastStoreKey, McastRole>consistentMapBuilder()
+                .<McastRoleStoreKey, McastRole>consistentMapBuilder()
                 .withName("onos-mcast-role-store")
                 .withSerializer(Serializer.using(mcastKryo.build("McastHandler-Role")))
                 .build();
-        // Let's create McastUtils object
         mcastUtils = new McastUtils(srManager, coreAppId, log);
-        // Init the executor service and the buckets corrector
+        // Init the executor service, the buckets corrector and schedule the clean up
         executorService.scheduleWithFixedDelay(new McastBucketCorrector(), 10,
                                                MCAST_VERIFY_INTERVAL, TimeUnit.SECONDS);
-        // Schedule the clean up, this will allow the processing of the expired events
         executorService.scheduleAtFixedRate(mcastEventCache::cleanUp, 0,
                                             WAIT_TIME_MS, TimeUnit.MILLISECONDS);
     }
 
     /**
-     * Read initial multicast from mcast store.
+     * Read initial multicast configuration from mcast store.
      */
     public void init() {
         lastMcastChange = Instant.now();
         mcastLock();
         try {
             srManager.multicastRouteService.getRoutes().forEach(mcastRoute -> {
-                // Verify leadership on the operation
+                log.debug("Init group {}", mcastRoute.group());
                 if (!mcastUtils.isLeader(mcastRoute.group())) {
                     log.debug("Skip {} due to lack of leadership", mcastRoute.group());
                     return;
                 }
-                // FIXME To be addressed with multiple sources support
-                ConnectPoint source = srManager.multicastRouteService.sources(mcastRoute)
-                        .stream()
-                        .findFirst()
-                        .orElse(null);
-                // Get all the sinks and process them
                 McastRouteData mcastRouteData = srManager.multicastRouteService.routeData(mcastRoute);
-                Set<ConnectPoint> sinks = processSinksToBeAdded(source, mcastRoute.group(), mcastRouteData.sinks());
-                // Filter out all the working sinks, we do not want to move them
-                sinks = sinks.stream()
-                        .filter(sink -> {
-                            McastStoreKey mcastKey = new McastStoreKey(mcastRoute.group(), sink.deviceId());
-                            Versioned<NextObjective> verMcastNext = mcastNextObjStore.get(mcastKey);
-                            return verMcastNext == null ||
-                                    !mcastUtils.getPorts(verMcastNext.value().next()).contains(sink.port());
-                        })
-                        .collect(Collectors.toSet());
-                // Compute the Mcast tree
-                Map<ConnectPoint, List<Path>> mcasTree = computeSinkMcastTree(source.deviceId(), sinks);
-                // Process the given sinks using the pre-computed paths
-                mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink,
-                                                                           mcastRoute.group(), paths));
+                // For each source process the mcast tree
+                srManager.multicastRouteService.sources(mcastRoute).forEach(source -> {
+                    Map<ConnectPoint, List<ConnectPoint>> mcastPaths = Maps.newHashMap();
+                    Set<DeviceId> visited = Sets.newHashSet();
+                    List<ConnectPoint> currentPath = Lists.newArrayList(source);
+                    buildMcastPaths(source.deviceId(), visited, mcastPaths,
+                                    currentPath, mcastRoute.group(), source);
+                    // Get all the sinks and process them
+                    Set<ConnectPoint> sinks = processSinksToBeAdded(source, mcastRoute.group(),
+                                                                    mcastRouteData.sinks());
+                    // Filter out all the working sinks, we do not want to move them
+                    // TODO we need a better way to distinguish flows coming from different sources
+                    sinks = sinks.stream()
+                            .filter(sink -> !mcastPaths.containsKey(sink) ||
+                                    !isSinkForSource(mcastRoute.group(), sink, source))
+                            .collect(Collectors.toSet());
+                    if (sinks.isEmpty()) {
+                        log.debug("Skip {} for source {} nothing to do", mcastRoute.group(), source);
+                        return;
+                    }
+                    Map<ConnectPoint, List<Path>> mcasTree = computeSinkMcastTree(source.deviceId(), sinks);
+                    mcasTree.forEach((sink, paths) -> processSinkAddedInternal(source, sink,
+                                                                               mcastRoute.group(), paths));
+                });
             });
         } finally {
             mcastUnlock();
@@ -363,7 +324,7 @@
 
     /**
      * Processes the SOURCE_ADDED, SOURCE_UPDATED, SINK_ADDED,
-     * SINK_REMOVED and ROUTE_REMOVED events.
+     * SINK_REMOVED, ROUTE_ADDED and ROUTE_REMOVED events.
      *
      * @param event McastEvent with SOURCE_ADDED type
      */
@@ -371,15 +332,165 @@
         log.info("process {}", event);
         // If it is a route added, we do not enqueue
         if (event.type() == ROUTE_ADDED) {
-            // We need just to elect a leader
             processRouteAddedInternal(event.subject().route().group());
         } else {
-            // Just enqueue for now
             enqueueMcastEvent(event);
         }
     }
 
     /**
+     * Handles the SOURCES_ADDED event, connecting each added source to its sinks.
+     *
+     * @param sources the connect points of the sources being added
+     * @param mcastIp the multicast group address
+     * @param sinks the sink connect points, indexed by host
+     */
+    private void processSourcesAddedInternal(Set<ConnectPoint> sources, IpAddress mcastIp,
+                                             Map<HostId, Set<ConnectPoint>> sinks) {
+        lastMcastChange = Instant.now();
+        mcastLock();
+        try {
+            log.debug("Processing sources added {} for group {}", sources, mcastIp);
+            if (!mcastUtils.isLeader(mcastIp)) {
+                log.debug("Skip {} due to lack of leadership", mcastIp);
+                return;
+            }
+            for (ConnectPoint source : sources) {
+                Set<ConnectPoint> newSinks = processSinksToBeAdded(source, mcastIp, sinks);
+                computeSinkMcastTree(source.deviceId(), newSinks)
+                        .forEach((sink, paths) -> processSinkAddedInternal(source, sink, mcastIp, paths));
+            }
+        } finally {
+            mcastUnlock();
+        }
+    }
+
+    /**
+     * Processes the SOURCES_REMOVED event, removing the state not shared with the remaining sources.
+     *
+     * @param sourcesToBeRemoved the source connect points to be removed
+     * @param remainingSources the remaining source connect points
+     * @param mcastIp the group address
+     * @param sinks the sinks connect points
+     */
+    private void processSourcesRemovedInternal(Set<ConnectPoint> sourcesToBeRemoved,
+                                               Set<ConnectPoint> remainingSources,
+                                               IpAddress mcastIp,
+                                               Map<HostId, Set<ConnectPoint>> sinks) {
+        lastMcastChange = Instant.now();
+        mcastLock();
+        try {
+            log.debug("Processing sources removed {} for group {}", sourcesToBeRemoved, mcastIp);
+            if (!mcastUtils.isLeader(mcastIp)) {
+                log.debug("Skip {} due to lack of leadership", mcastIp);
+                return;
+            }
+            if (remainingSources.isEmpty()) {
+                processRouteRemovedInternal(sourcesToBeRemoved, mcastIp);
+                return;
+            }
+            // Skip offline devices
+            Set<ConnectPoint> candidateSources = sourcesToBeRemoved.stream()
+                    .filter(source -> srManager.deviceService.isAvailable(source.deviceId()))
+                    .collect(Collectors.toSet());
+            if (candidateSources.isEmpty()) {
+                log.debug("Skip {} due to empty sources to be removed", mcastIp);
+                return;
+            }
+            Set<Link> remainingLinks = Sets.newHashSet();
+            Map<ConnectPoint, Set<Link>> candidateLinks = Maps.newHashMap();
+            Map<ConnectPoint, Set<ConnectPoint>> candidateSinks = Maps.newHashMap();
+            Set<ConnectPoint> totalSources = Sets.newHashSet(candidateSources);
+            totalSources.addAll(remainingSources);
+            // Collect the links used by each source; links of every sink path are accumulated per source
+            totalSources.forEach(source -> {
+                Set<ConnectPoint> currentSinks = sinks.values()
+                        .stream().flatMap(Collection::stream)
+                        .filter(sink -> isSinkForSource(mcastIp, sink, source))
+                        .collect(Collectors.toSet());
+                candidateSinks.put(source, currentSinks);
+                currentSinks.forEach(currentSink -> {
+                    Optional<Path> currentPath = getPath(source.deviceId(), currentSink.deviceId(),
+                                                         mcastIp, null, source);
+                    currentPath.ifPresent(path -> {
+                        if (!candidateSources.contains(source)) {
+                            remainingLinks.addAll(path.links());
+                        } else {
+                            candidateLinks.computeIfAbsent(source, k -> Sets.newHashSet()).addAll(path.links());
+                        }
+                    });
+                });
+            });
+            // Clean transit links
+            candidateLinks.forEach((source, currentCandidateLinks) -> {
+                Set<Link> linksToBeRemoved = Sets.difference(currentCandidateLinks, remainingLinks)
+                        .immutableCopy();
+                if (!linksToBeRemoved.isEmpty()) {
+                    currentCandidateLinks.forEach(link -> {
+                        DeviceId srcLink = link.src().deviceId();
+                        // Remove ports only on links to be removed
+                        if (linksToBeRemoved.contains(link)) {
+                            removePortFromDevice(link.src().deviceId(), link.src().port(), mcastIp,
+                                                 mcastUtils.assignedVlan(srcLink.equals(source.deviceId()) ?
+                                                                                 source : null));
+                        }
+                        // Remove role on the candidate links
+                        mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, srcLink, source));
+                    });
+                }
+            });
+            // Clean ingress and egress
+            candidateSources.forEach(source -> {
+                Set<ConnectPoint> currentSinks = candidateSinks.get(source);
+                currentSinks.forEach(currentSink -> {
+                    VlanId assignedVlan = mcastUtils.assignedVlan(source.deviceId().equals(currentSink.deviceId()) ?
+                                                                          source : null);
+                    // Sinks co-located with the source
+                    if (source.deviceId().equals(currentSink.deviceId())) {
+                        if (source.port().equals(currentSink.port())) {
+                            log.warn("Skip {} since sink {} is on the same port of source {}. Abort",
+                                     mcastIp, currentSink, source);
+                            return;
+                        }
+                        // We need to check against the other sources and if it is
+                        // necessary remove the port from the device - no overlap
+                        Set<VlanId> otherVlans = remainingSources.stream()
+                                // Only sources co-located and having this sink
+                                .filter(remainingSource -> remainingSource.deviceId()
+                                        .equals(source.deviceId()) && candidateSinks.get(remainingSource)
+                                        .contains(currentSink))
+                                .map(remainingSource -> mcastUtils.assignedVlan(
+                                        remainingSource.deviceId().equals(currentSink.deviceId()) ?
+                                                remainingSource : null)).collect(Collectors.toSet());
+                        if (!otherVlans.contains(assignedVlan)) {
+                            removePortFromDevice(currentSink.deviceId(), currentSink.port(),
+                                                 mcastIp, assignedVlan);
+                        }
+                        mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, currentSink.deviceId(),
+                                                                    source));
+                        return;
+                    }
+                    Set<VlanId> otherVlans = remainingSources.stream()
+                            .filter(remainingSource -> candidateSinks.get(remainingSource)
+                                    .contains(currentSink))
+                            .map(remainingSource -> mcastUtils.assignedVlan(
+                                    remainingSource.deviceId().equals(currentSink.deviceId()) ?
+                                            remainingSource : null)).collect(Collectors.toSet());
+                    // Sinks on other leaves
+                    if (!otherVlans.contains(assignedVlan)) {
+                        removePortFromDevice(currentSink.deviceId(), currentSink.port(),
+                                             mcastIp, assignedVlan);
+                    }
+                    mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, currentSink.deviceId(),
+                                                                source));
+                });
+            });
+        } finally {
+            mcastUnlock();
+        }
+    }
+
+    /**
      * Process the ROUTE_ADDED event.
      *
      * @param mcastIp the group address
@@ -398,76 +509,92 @@
 
     /**
      * Removes the entire mcast tree related to this group.
-     *
+     * @param sources the source connect points whose trees have to be removed
      * @param mcastIp multicast group IP address
      */
-    private void processRouteRemovedInternal(ConnectPoint source, IpAddress mcastIp) {
+    private void processRouteRemovedInternal(Set<ConnectPoint> sources, IpAddress mcastIp) {
         lastMcastChange = Instant.now();
         mcastLock();
         try {
             log.debug("Processing route removed for group {}", mcastIp);
-            // Verify leadership on the operation
             if (!mcastUtils.isLeader(mcastIp)) {
                 log.debug("Skip {} due to lack of leadership", mcastIp);
                 mcastUtils.withdrawLeader(mcastIp);
                 return;
             }
-
-            // Find out the ingress, transit and egress device of the affected group
-            DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
-                    .stream().findAny().orElse(null);
-            Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
-            Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
-
-            // If there are no egress devices, sinks could be only on the ingress
-            if (!egressDevices.isEmpty()) {
-                egressDevices.forEach(
-                        deviceId -> removeGroupFromDevice(deviceId, mcastIp, mcastUtils.assignedVlan(null))
-                );
-            }
-            // Transit could be empty if sinks are on the ingress
-            if (!transitDevices.isEmpty()) {
-                transitDevices.forEach(
-                        deviceId -> removeGroupFromDevice(deviceId, mcastIp, mcastUtils.assignedVlan(null))
-                );
-            }
-            // Ingress device should be not null
-            if (ingressDevice != null) {
-                removeGroupFromDevice(ingressDevice, mcastIp, mcastUtils.assignedVlan(source));
-            }
+            sources.forEach(source -> {
+                // Find out the ingress, transit and egress devices recorded for this source and group
+                DeviceId ingressDevice = getDevice(mcastIp, INGRESS, source)
+                        .stream().findFirst().orElse(null);
+                Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT, source);
+                Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS, source);
+                // If there are no egress and transit devices, sinks could be only on the ingress
+                if (!egressDevices.isEmpty()) {
+                    egressDevices.forEach(deviceId -> {
+                        removeGroupFromDevice(deviceId, mcastIp, mcastUtils.assignedVlan(null));
+                        mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, deviceId, source));
+                    });
+                }
+                if (!transitDevices.isEmpty()) {
+                    transitDevices.forEach(deviceId -> {
+                        removeGroupFromDevice(deviceId, mcastIp, mcastUtils.assignedVlan(null));
+                        mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, deviceId, source));
+                    });
+                }
+                if (ingressDevice != null) {
+                    removeGroupFromDevice(ingressDevice, mcastIp, mcastUtils.assignedVlan(source));
+                    mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, ingressDevice, source));
+                }
+            });
+            // Every source has been processed: finally, withdraw the leadership for the group
+            mcastUtils.withdrawLeader(mcastIp);
         } finally {
             mcastUnlock();
         }
     }
 
-
     /**
      * Process sinks to be removed.
      *
-     * @param source the source connect point
+     * @param sources the source connect points
      * @param mcastIp the ip address of the group
      * @param newSinks the new sinks to be processed
      * @param prevSinks the previous sinks
      */
-    private void processSinksRemovedInternal(ConnectPoint source, IpAddress mcastIp,
+    private void processSinksRemovedInternal(Set<ConnectPoint> sources, IpAddress mcastIp,
                                              Map<HostId, Set<ConnectPoint>> newSinks,
                                              Map<HostId, Set<ConnectPoint>> prevSinks) {
         lastMcastChange = Instant.now();
         mcastLock();
         try {
-            // Verify leadership on the operation
             if (!mcastUtils.isLeader(mcastIp)) {
                 log.debug("Skip {} due to lack of leadership", mcastIp);
                 return;
             }
-            // Remove the previous ones
-            Set<ConnectPoint> sinksToBeRemoved = processSinksToBeRemoved(mcastIp, prevSinks,
-                                                                         newSinks);
-            sinksToBeRemoved.forEach(sink -> processSinkRemovedInternal(source, sink, mcastIp));
-            // Recover the dual-homed sinks
-            Set<ConnectPoint> sinksToBeRecovered = processSinksToBeRecovered(mcastIp, newSinks,
-                                                                             prevSinks);
-            sinksToBeRecovered.forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, null));
+            log.debug("Processing sinks removed for group {} and for sources {}",
+                      mcastIp, sources);
+            Map<ConnectPoint, Map<ConnectPoint, Optional<Path>>> treesToBeRemoved = Maps.newHashMap();
+            Map<ConnectPoint, Set<ConnectPoint>> treesToBeAdded = Maps.newHashMap();
+            sources.forEach(source -> {
+                // Save the path associated to the sinks to be removed
+                Set<ConnectPoint> sinksToBeRemoved = processSinksToBeRemoved(mcastIp, prevSinks,
+                                                                             newSinks, source);
+                Map<ConnectPoint, Optional<Path>> treeToBeRemoved = Maps.newHashMap();
+                sinksToBeRemoved.forEach(sink -> treeToBeRemoved.put(sink, getPath(source.deviceId(),
+                                                                                   sink.deviceId(), mcastIp,
+                                                                                   null, source)));
+                treesToBeRemoved.put(source, treeToBeRemoved);
+                // Recover the dual-homed sinks
+                Set<ConnectPoint> sinksToBeRecovered = processSinksToBeRecovered(mcastIp, newSinks,
+                                                                                 prevSinks, source);
+                treesToBeAdded.put(source, sinksToBeRecovered);
+            });
+            // Remove the sinks taking into account the multiple sources and the original paths
+            treesToBeRemoved.forEach((source, tree) ->
+                tree.forEach((sink, path) -> processSinkRemovedInternal(source, sink, mcastIp, path)));
+            // Add new sinks according to the recovery procedure
+            treesToBeAdded.forEach((source, sinks) ->
+                sinks.forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, null)));
         } finally {
             mcastUnlock();
         }
@@ -479,51 +606,43 @@
      * @param source connect point of the multicast source
      * @param sink connection point of the multicast sink
      * @param mcastIp multicast group IP address
+     * @param mcastPath path associated to the sink
      */
     private void processSinkRemovedInternal(ConnectPoint source, ConnectPoint sink,
-                                            IpAddress mcastIp) {
+                                            IpAddress mcastIp, Optional<Path> mcastPath) {
         lastMcastChange = Instant.now();
         mcastLock();
         try {
+            log.debug("Processing sink removed {} for group {} and for source {}", sink, mcastIp, source);
             boolean isLast;
             // When source and sink are on the same device
             if (source.deviceId().equals(sink.deviceId())) {
                 // Source and sink are on even the same port. There must be something wrong.
                 if (source.port().equals(sink.port())) {
-                    log.warn("Skip {} since sink {} is on the same port of source {}. Abort",
-                             mcastIp, sink, source);
+                    log.warn("Skip {} since sink {} is on the same port of source {}. Abort", mcastIp, sink, source);
                     return;
                 }
                 isLast = removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, mcastUtils.assignedVlan(source));
                 if (isLast) {
-                    mcastRoleStore.remove(new McastStoreKey(mcastIp, sink.deviceId()));
+                    mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, sink.deviceId(), source));
                 }
                 return;
             }
-
             // Process the egress device
             isLast = removePortFromDevice(sink.deviceId(), sink.port(), mcastIp, mcastUtils.assignedVlan(null));
             if (isLast) {
-                mcastRoleStore.remove(new McastStoreKey(mcastIp, sink.deviceId()));
+                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, sink.deviceId(), source));
             }
-
             // If this is the last sink on the device, also update upstream
-            Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(),
-                                               mcastIp, null);
             if (mcastPath.isPresent()) {
                 List<Link> links = Lists.newArrayList(mcastPath.get().links());
                 Collections.reverse(links);
                 for (Link link : links) {
                     if (isLast) {
-                        isLast = removePortFromDevice(
-                                link.src().deviceId(),
-                                link.src().port(),
-                                mcastIp,
-                                mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ?
-                                                     source : null)
-                        );
+                        isLast = removePortFromDevice(link.src().deviceId(), link.src().port(), mcastIp,
+                        mcastUtils.assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
                         if (isLast) {
-                            mcastRoleStore.remove(new McastStoreKey(mcastIp, link.src().deviceId()));
+                            mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, link.src().deviceId(), source));
                         }
                     }
                 }
@@ -537,27 +656,27 @@
     /**
      * Process sinks to be added.
      *
-     * @param source the source connect point
+     * @param sources the source connect points
      * @param mcastIp the group IP
      * @param newSinks the new sinks to be processed
      * @param allPrevSinks all previous sinks
      */
-    private void processSinksAddedInternal(ConnectPoint source, IpAddress mcastIp,
+    private void processSinksAddedInternal(Set<ConnectPoint> sources, IpAddress mcastIp,
                                            Map<HostId, Set<ConnectPoint>> newSinks,
                                            Set<ConnectPoint> allPrevSinks) {
         lastMcastChange = Instant.now();
         mcastLock();
         try {
-            // Verify leadership on the operation
             if (!mcastUtils.isLeader(mcastIp)) {
                 log.debug("Skip {} due to lack of leadership", mcastIp);
                 return;
             }
-            // Get the only sinks to be processed (new ones)
-            Set<ConnectPoint> sinksToBeAdded = processSinksToBeAdded(source, mcastIp, newSinks);
-            // Install new sinks
-            sinksToBeAdded = Sets.difference(sinksToBeAdded, allPrevSinks);
-            sinksToBeAdded.forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, null));
+            log.debug("Processing sinks added for group {} and for sources {}", mcastIp, sources);
+            sources.forEach(source -> {
+                Set<ConnectPoint> sinksToBeAdded = processSinksToBeAdded(source, mcastIp, newSinks);
+                sinksToBeAdded = Sets.difference(sinksToBeAdded, allPrevSinks);
+                sinksToBeAdded.forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, null));
+            });
         } finally {
             mcastUnlock();
         }
@@ -575,34 +694,27 @@
         lastMcastChange = Instant.now();
         mcastLock();
         try {
+            log.debug("Processing sink added {} for group {} and for source {}", sink, mcastIp, source);
             // Process the ingress device
             mcastUtils.addFilterToDevice(source.deviceId(), source.port(),
-                              mcastUtils.assignedVlan(source), mcastIp, INGRESS);
-
-            // When source and sink are on the same device
+                                         mcastUtils.assignedVlan(source), mcastIp, INGRESS);
             if (source.deviceId().equals(sink.deviceId())) {
-                // Source and sink are on even the same port. There must be something wrong.
                 if (source.port().equals(sink.port())) {
                     log.warn("Skip {} since sink {} is on the same port of source {}. Abort",
                              mcastIp, sink, source);
                     return;
                 }
                 addPortToDevice(sink.deviceId(), sink.port(), mcastIp, mcastUtils.assignedVlan(source));
-                mcastRoleStore.put(new McastStoreKey(mcastIp, sink.deviceId()), INGRESS);
+                mcastRoleStore.put(new McastRoleStoreKey(mcastIp, sink.deviceId(), source), INGRESS);
                 return;
             }
-
             // Find a path. If present, create/update groups and flows for each hop
-            Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(),
-                                               mcastIp, allPaths);
+            Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(), mcastIp, allPaths, source);
             if (mcastPath.isPresent()) {
                 List<Link> links = mcastPath.get().links();
-
                 // Setup mcast role for ingress
-                mcastRoleStore.put(new McastStoreKey(mcastIp, source.deviceId()),
-                                   INGRESS);
-
-                // Setup properly the transit
+                mcastRoleStore.put(new McastRoleStoreKey(mcastIp, source.deviceId(), source), INGRESS);
+                // Setup properly the transit forwarding
                 links.forEach(link -> {
                     addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp,
                                     mcastUtils.assignedVlan(link.src().deviceId()
@@ -610,21 +722,17 @@
                     mcastUtils.addFilterToDevice(link.dst().deviceId(), link.dst().port(),
                                       mcastUtils.assignedVlan(null), mcastIp, null);
                 });
-
                 // Setup mcast role for the transit
                 links.stream()
                         .filter(link -> !link.dst().deviceId().equals(sink.deviceId()))
-                        .forEach(link -> mcastRoleStore.put(new McastStoreKey(mcastIp, link.dst().deviceId()),
-                                                            TRANSIT));
-
+                        .forEach(link -> mcastRoleStore.put(new McastRoleStoreKey(mcastIp, link.dst().deviceId(),
+                                                                                  source), TRANSIT));
                 // Process the egress device
                 addPortToDevice(sink.deviceId(), sink.port(), mcastIp, mcastUtils.assignedVlan(null));
                 // Setup mcast role for egress
-                mcastRoleStore.put(new McastStoreKey(mcastIp, sink.deviceId()),
-                                   EGRESS);
+                mcastRoleStore.put(new McastRoleStoreKey(mcastIp, sink.deviceId(), source), EGRESS);
             } else {
-                log.warn("Unable to find a path from {} to {}. Abort sinkAdded",
-                         source.deviceId(), sink.deviceId());
+                log.warn("Unable to find a path from {} to {}. Abort sinkAdded", source.deviceId(), sink.deviceId());
             }
         } finally {
             mcastUnlock();
@@ -642,77 +750,8 @@
         try {
             // Get groups affected by the link down event
             getAffectedGroups(affectedLink).forEach(mcastIp -> {
-                // TODO Optimize when the group editing is in place
-                log.debug("Processing link down {} for group {}",
-                          affectedLink, mcastIp);
-                // Verify leadership on the operation
-                if (!mcastUtils.isLeader(mcastIp)) {
-                    log.debug("Skip {} due to lack of leadership", mcastIp);
-                    return;
-                }
-
-                // Find out the ingress, transit and egress device of affected group
-                DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
-                        .stream().findAny().orElse(null);
-                Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
-                Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
-                ConnectPoint source = mcastUtils.getSource(mcastIp);
-
-                // Do not proceed if ingress device or source of this group are missing
-                // If sinks are in other leafs, we have ingress, transit, egress, and source
-                // If sinks are in the same leaf, we have just ingress and source
-                if (ingressDevice == null || source == null) {
-                    log.warn("Missing ingress {} or source {} for group {}",
-                             ingressDevice, source, mcastIp);
-                    return;
-                }
-
-                // Remove entire transit
-                transitDevices.forEach(transitDevice ->
-                                removeGroupFromDevice(transitDevice, mcastIp,
-                                                      mcastUtils.assignedVlan(null)));
-
-                // Remove transit-facing ports on the ingress device
-                removeIngressTransitPorts(mcastIp, ingressDevice, source);
-
-                // TODO create a shared procedure with DEVICE_DOWN
-                // Compute mcast tree for the the egress devices
-                Map<DeviceId, List<Path>> mcastTree = computeMcastTree(ingressDevice, egressDevices);
-
-                // We have to verify, if there are egresses without paths
-                Set<DeviceId> notRecovered = Sets.newHashSet();
-                mcastTree.forEach((egressDevice, paths) -> {
-                    // Let's check if there is at least a path
-                    Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                       mcastIp, paths);
-                    // No paths, we have to try with alternative location
-                    if (!mcastPath.isPresent()) {
-                        notRecovered.add(egressDevice);
-                        // We were not able to find an alternative path for this egress
-                        log.warn("Fail to recover egress device {} from link failure {}",
-                                 egressDevice, affectedLink);
-                        removeGroupFromDevice(egressDevice, mcastIp,
-                                              mcastUtils.assignedVlan(null));
-                    }
-                });
-
-                // Fast path, we can recover all the locations
-                if (notRecovered.isEmpty()) {
-                    // Construct a new path for each egress device
-                    mcastTree.forEach((egressDevice, paths) -> {
-                        // We try to enforce the sinks path on the mcast tree
-                        Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                           mcastIp, paths);
-                        // If a path is present, let's install it
-                        if (mcastPath.isPresent()) {
-                            installPath(mcastIp, source, mcastPath.get());
-                        }
-                    });
-                } else {
-                    // Let's try to recover using alternate
-                    recoverSinks(egressDevices, notRecovered, mcastIp,
-                                 ingressDevice, source, true);
-                }
+                log.debug("Processing link down {} for group {}", affectedLink, mcastIp);
+                recoverFailure(mcastIp, affectedLink);
             });
         } finally {
             mcastUnlock();
@@ -730,105 +769,8 @@
         try {
             // Get the mcast groups affected by the device going down
             getAffectedGroups(deviceDown).forEach(mcastIp -> {
-                // TODO Optimize when the group editing is in place
-                log.debug("Processing device down {} for group {}",
-                          deviceDown, mcastIp);
-                // Verify leadership on the operation
-                if (!mcastUtils.isLeader(mcastIp)) {
-                    log.debug("Skip {} due to lack of leadership", mcastIp);
-                    return;
-                }
-
-                // Find out the ingress, transit and egress device of affected group
-                DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
-                        .stream().findAny().orElse(null);
-                Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
-                Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
-                ConnectPoint source = mcastUtils.getSource(mcastIp);
-
-                // Do not proceed if ingress device or source of this group are missing
-                // If sinks are in other leafs, we have ingress, transit, egress, and source
-                // If sinks are in the same leaf, we have just ingress and source
-                if (ingressDevice == null || source == null) {
-                    log.warn("Missing ingress {} or source {} for group {}",
-                             ingressDevice, source, mcastIp);
-                    return;
-                }
-
-                // If it exists, we have to remove it in any case
-                if (!transitDevices.isEmpty()) {
-                    // Remove entire transit
-                    transitDevices.forEach(transitDevice ->
-                                    removeGroupFromDevice(transitDevice, mcastIp,
-                                                          mcastUtils.assignedVlan(null)));
-                }
-                // If the ingress is down
-                if (ingressDevice.equals(deviceDown)) {
-                    // Remove entire ingress
-                    removeGroupFromDevice(ingressDevice, mcastIp, mcastUtils.assignedVlan(source));
-                    // If other sinks different from the ingress exist
-                    if (!egressDevices.isEmpty()) {
-                        // Remove all the remaining egress
-                        egressDevices.forEach(
-                                egressDevice -> removeGroupFromDevice(egressDevice, mcastIp,
-                                                                      mcastUtils.assignedVlan(null))
-                        );
-                    }
-                } else {
-                    // Egress or transit could be down at this point
-                    // Get the ingress-transit ports if they exist
-                    removeIngressTransitPorts(mcastIp, ingressDevice, source);
-
-                    // One of the egress device is down
-                    if (egressDevices.contains(deviceDown)) {
-                        // Remove entire device down
-                        removeGroupFromDevice(deviceDown, mcastIp, mcastUtils.assignedVlan(null));
-                        // Remove the device down from egress
-                        egressDevices.remove(deviceDown);
-                        // If there are no more egress and ingress does not have sinks
-                        if (egressDevices.isEmpty() && !hasSinks(ingressDevice, mcastIp)) {
-                            // We have done
-                            return;
-                        }
-                    }
-
-                    // Compute mcast tree for the the egress devices
-                    Map<DeviceId, List<Path>> mcastTree = computeMcastTree(ingressDevice, egressDevices);
-
-                    // We have to verify, if there are egresses without paths
-                    Set<DeviceId> notRecovered = Sets.newHashSet();
-                    mcastTree.forEach((egressDevice, paths) -> {
-                        // Let's check if there is at least a path
-                        Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                           mcastIp, paths);
-                        // No paths, we have to try with alternative location
-                        if (!mcastPath.isPresent()) {
-                            notRecovered.add(egressDevice);
-                            // We were not able to find an alternative path for this egress
-                            log.warn("Fail to recover egress device {} from device down {}",
-                                     egressDevice, deviceDown);
-                            removeGroupFromDevice(egressDevice, mcastIp, mcastUtils.assignedVlan(null));
-                        }
-                    });
-
-                    // Fast path, we can recover all the locations
-                    if (notRecovered.isEmpty()) {
-                        // Construct a new path for each egress device
-                        mcastTree.forEach((egressDevice, paths) -> {
-                            // We try to enforce the sinks path on the mcast tree
-                            Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                                               mcastIp, paths);
-                            // If a path is present, let's install it
-                            if (mcastPath.isPresent()) {
-                                installPath(mcastIp, source, mcastPath.get());
-                            }
-                        });
-                    } else {
-                        // Let's try to recover using alternate
-                        recoverSinks(egressDevices, notRecovered, mcastIp,
-                                     ingressDevice, source, false);
-                    }
-                }
+                log.debug("Processing device down {} for group {}", deviceDown, mcastIp);
+                recoverFailure(mcastIp, deviceDown);
             });
         } finally {
             mcastUnlock();
@@ -836,6 +778,100 @@
     }
 
     /**
+     * General failure recovery procedure.
+     *
+     * @param mcastIp the group to recover
+     * @param failedElement the failed element
+     */
+    private void recoverFailure(IpAddress mcastIp, Object failedElement) {
+        // TODO Optimize when the group editing is in place
+        if (!mcastUtils.isLeader(mcastIp)) {
+            log.debug("Skip {} due to lack of leadership", mcastIp);
+            return;
+        }
+        // Do not proceed if the sources of this group are missing
+        Set<ConnectPoint> sources = getSources(mcastIp);
+        if (sources.isEmpty()) {
+            log.warn("Missing sources for group {}", mcastIp);
+            return;
+        }
+        // Find out the ingress devices of the affected group
+        // If sinks are in other leafs, we have ingress, transit, egress, and source
+        // If sinks are in the same leaf, we have just ingress and source
+        Set<DeviceId> ingressDevices = getDevice(mcastIp, INGRESS);
+        if (ingressDevices.isEmpty()) {
+            log.warn("Missing ingress devices for group {}", mcastIp);
+            return;
+        }
+        // For each tree, delete ingress-transit part
+        sources.forEach(source -> {
+            Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT, source);
+            transitDevices.forEach(transitDevice -> {
+                removeGroupFromDevice(transitDevice, mcastIp, mcastUtils.assignedVlan(null));
+                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, transitDevice, source));
+            });
+        });
+        removeIngressTransitPorts(mcastIp, ingressDevices, sources);
+        // TODO Evaluate the possibility of building optimized trees between sources
+        Map<DeviceId, Set<ConnectPoint>> notRecovered = Maps.newHashMap();
+        sources.forEach(source -> {
+            Set<DeviceId> notRecoveredInternal = Sets.newHashSet();
+            DeviceId ingressDevice = ingressDevices.stream()
+                    .filter(deviceId -> deviceId.equals(source.deviceId())).findFirst().orElse(null);
+            // The lookup above yields null when no ingress matches this source: bail out
+            // before dereferencing it, otherwise a device failure would raise an NPE here
+            if (ingressDevice == null) {
+                log.warn("Skip failure recovery - " +
+                                 "Missing ingress for source {} and group {}", source, mcastIp);
+                return;
+            }
+            // Clean also the ingress
+            if (failedElement instanceof DeviceId && ingressDevice.equals(failedElement)) {
+                removeGroupFromDevice((DeviceId) failedElement, mcastIp, mcastUtils.assignedVlan(source));
+                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, (DeviceId) failedElement, source));
+            }
+            Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS, source);
+            Map<DeviceId, List<Path>> mcastTree = computeMcastTree(ingressDevice, egressDevices);
+            // We have to verify, if there are egresses without paths
+            mcastTree.forEach((egressDevice, paths) -> {
+                Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
+                                                   mcastIp, paths, source);
+                // No paths, we have to try with alternative location
+                if (!mcastPath.isPresent()) {
+                    // Remember, per egress, the sources that lost their path towards it
+                    notRecovered.computeIfAbsent(egressDevice, deviceId -> Sets.newHashSet()).add(source);
+                    notRecoveredInternal.add(egressDevice);
+                }
+            });
+            // Fast path, we can recover all the locations
+            if (notRecoveredInternal.isEmpty()) {
+                mcastTree.forEach((egressDevice, paths) -> {
+                    Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
+                                                       mcastIp, paths, source);
+                    if (mcastPath.isPresent()) {
+                        installPath(mcastIp, source, mcastPath.get());
+                    }
+                });
+            } else {
+                // Let's try to recover using alternative locations
+                recoverSinks(egressDevices, notRecoveredInternal, mcastIp,
+                             ingressDevice, source);
+            }
+        });
+        // Finally remove the egresses not recovered: the group is dropped only when the
+        // failed sources equal the ones returned for that egress, roles are always removed
+        notRecovered.forEach((egressDevice, listSources) -> {
+            Set<ConnectPoint> currentSources = getSources(mcastIp, egressDevice, EGRESS);
+            if (Objects.equal(currentSources, listSources)) {
+                log.warn("Fail to recover egress device {} from {} failure {}",
+                         egressDevice, failedElement instanceof Link ? "Link" : "Device", failedElement);
+                removeGroupFromDevice(egressDevice, mcastIp, mcastUtils.assignedVlan(null));
+            }
+            listSources.forEach(source -> mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, egressDevice, source)));
+        });
+    }
+
+    /**
      * Try to recover sinks using alternate locations.
      *
      * @param egressDevices the original egress devices
@@ -843,65 +879,44 @@
      * @param mcastIp the group address
      * @param ingressDevice the ingress device
      * @param source the source connect point
-     * @param isLinkFailure true if it is a link failure, otherwise false
      */
     private void recoverSinks(Set<DeviceId> egressDevices, Set<DeviceId> notRecovered,
-                              IpAddress mcastIp, DeviceId ingressDevice, ConnectPoint source,
-                              boolean isLinkFailure) {
-        // Recovered devices
+                              IpAddress mcastIp, DeviceId ingressDevice, ConnectPoint source) {
+        log.debug("Processing recover sinks for group {} and for source {}",
+                  mcastIp, source);
         Set<DeviceId> recovered = Sets.difference(egressDevices, notRecovered);
-        // Total affected sinks
         Set<ConnectPoint> totalAffectedSinks = Sets.newHashSet();
-        // Total sinks
         Set<ConnectPoint> totalSinks = Sets.newHashSet();
         // Let's compute all the affected sinks and all the sinks
         notRecovered.forEach(deviceId -> {
             totalAffectedSinks.addAll(
-                    mcastUtils.getAffectedSinks(deviceId, mcastIp)
-                            .values()
-                            .stream()
+                    mcastUtils.getAffectedSinks(deviceId, mcastIp).values().stream()
                             .flatMap(Collection::stream)
                             .filter(connectPoint -> connectPoint.deviceId().equals(deviceId))
                             .collect(Collectors.toSet())
             );
             totalSinks.addAll(
-                    mcastUtils.getAffectedSinks(deviceId, mcastIp)
-                            .values()
-                            .stream()
-                            .flatMap(Collection::stream)
-                            .collect(Collectors.toSet())
+                    mcastUtils.getAffectedSinks(deviceId, mcastIp).values().stream()
+                            .flatMap(Collection::stream).collect(Collectors.toSet())
             );
         });
-
-        // Sinks to be added
         Set<ConnectPoint> sinksToBeAdded = Sets.difference(totalSinks, totalAffectedSinks);
-        // New egress devices, filtering out the source
-        Set<DeviceId> newEgressDevice = sinksToBeAdded.stream()
-                .map(ConnectPoint::deviceId)
-                .collect(Collectors.toSet());
-        // Let's add the devices recovered from the previous round
-        newEgressDevice.addAll(recovered);
-        // Let's do a copy of the new egresses and filter out the source
-        Set<DeviceId> copyNewEgressDevice = ImmutableSet.copyOf(newEgressDevice);
-        newEgressDevice = newEgressDevice.stream()
-                .filter(deviceId -> !deviceId.equals(ingressDevice))
-                .collect(Collectors.toSet());
-
-        // Re-compute mcast tree for the the egress devices
-        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(ingressDevice, newEgressDevice);
+        Set<DeviceId> newEgressDevices = sinksToBeAdded.stream()
+                .map(ConnectPoint::deviceId).collect(Collectors.toSet());
+        newEgressDevices.addAll(recovered);
+        Set<DeviceId> copyNewEgressDevices = ImmutableSet.copyOf(newEgressDevices);
+        newEgressDevices = newEgressDevices.stream()
+                .filter(deviceId -> !deviceId.equals(ingressDevice)).collect(Collectors.toSet());
+        Map<DeviceId, List<Path>> mcastTree = computeMcastTree(ingressDevice, newEgressDevices);
         // if the source was originally in the new locations, add new sinks
-        if (copyNewEgressDevice.contains(ingressDevice)) {
+        if (copyNewEgressDevices.contains(ingressDevice)) {
             sinksToBeAdded.stream()
                     .filter(connectPoint -> connectPoint.deviceId().equals(ingressDevice))
                     .forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, ImmutableList.of()));
         }
-
         // Construct a new path for each egress device
         mcastTree.forEach((egressDevice, paths) -> {
-            // We try to enforce the sinks path on the mcast tree
-            Optional<Path> mcastPath = getPath(ingressDevice, egressDevice,
-                                               mcastIp, paths);
-            // If a path is present, let's install it
+            Optional<Path> mcastPath = getPath(ingressDevice, egressDevice, mcastIp, paths, source);
             if (mcastPath.isPresent()) {
                 // Using recovery procedure
                 if (recovered.contains(egressDevice)) {
@@ -912,14 +927,8 @@
                             .filter(connectPoint -> connectPoint.deviceId().equals(egressDevice))
                             .forEach(sink -> processSinkAddedInternal(source, sink, mcastIp, paths));
                 }
-            } else {
-                // We were not able to find an alternative path for this egress
-                log.warn("Fail to recover egress device {} from {} failure",
-                         egressDevice, isLinkFailure ? "Link" : "Device");
-                removeGroupFromDevice(egressDevice, mcastIp, mcastUtils.assignedVlan(null));
             }
         });
-
     }
 
     /**
@@ -929,18 +938,18 @@
      * @param mcastIp the group address
      * @param prevsinks the previous sinks to be evaluated
      * @param newSinks the new sinks to be evaluted
+     * @param source the source connect point
      * @return the set of the sinks to be removed
      */
     private Set<ConnectPoint> processSinksToBeRemoved(IpAddress mcastIp,
                                                       Map<HostId, Set<ConnectPoint>> prevsinks,
-                                                      Map<HostId, Set<ConnectPoint>> newSinks) {
-        // Iterate over the sinks in order to build the set
-        // of the connect points to be removed from this group
+                                                      Map<HostId, Set<ConnectPoint>> newSinks,
+                                                      ConnectPoint source) {
         final Set<ConnectPoint> sinksToBeProcessed = Sets.newHashSet();
         prevsinks.forEach(((hostId, connectPoints) -> {
             // We have to check with the existing flows
             ConnectPoint sinkToBeProcessed = connectPoints.stream()
-                    .filter(connectPoint -> isSink(mcastIp, connectPoint))
+                    .filter(connectPoint -> isSinkForSource(mcastIp, connectPoint, source))
                     .findFirst().orElse(null);
             if (sinkToBeProcessed != null) {
                 // If the host has been removed or location has been removed
@@ -960,13 +969,13 @@
      *
      * @param newSinks the remaining sinks
      * @param prevSinks the previous sinks
+     * @param source the source connect point
      * @return the set of the sinks to be processed
      */
     private Set<ConnectPoint> processSinksToBeRecovered(IpAddress mcastIp,
                                                         Map<HostId, Set<ConnectPoint>> newSinks,
-                                                        Map<HostId, Set<ConnectPoint>> prevSinks) {
-        // Iterate over the sinks in order to build the set
-        // of the connect points to be served by this group
+                                                        Map<HostId, Set<ConnectPoint>> prevSinks,
+                                                        ConnectPoint source) {
         final Set<ConnectPoint> sinksToBeProcessed = Sets.newHashSet();
         newSinks.forEach((hostId, connectPoints) -> {
             // If it has more than 1 locations
@@ -979,7 +988,7 @@
             // Filter out if the remaining location is already served
             if (prevSinks.containsKey(hostId) && prevSinks.get(hostId).size() == 2) {
                 ConnectPoint sinkToBeProcessed = connectPoints.stream()
-                        .filter(connectPoint -> !isSink(mcastIp, connectPoint))
+                        .filter(connectPoint -> !isSinkForSource(mcastIp, connectPoint, source))
                         .findFirst().orElse(null);
                 if (sinkToBeProcessed != null) {
                     sinksToBeProcessed.add(sinkToBeProcessed);
@@ -1000,8 +1009,6 @@
      */
     private Set<ConnectPoint> processSinksToBeAdded(ConnectPoint source, IpAddress mcastIp,
                                                     Map<HostId, Set<ConnectPoint>> sinks) {
-        // Iterate over the sinks in order to build the set
-        // of the connect points to be served by this group
         final Set<ConnectPoint> sinksToBeProcessed = Sets.newHashSet();
         sinks.forEach(((hostId, connectPoints) -> {
             // If it has more than 2 locations
@@ -1012,23 +1019,44 @@
             }
             // If it has one location, just use it
             if (connectPoints.size() == 1) {
-                sinksToBeProcessed.add(connectPoints.stream()
-                                               .findFirst().orElse(null));
+                sinksToBeProcessed.add(connectPoints.stream().findFirst().orElse(null));
                 return;
             }
             // We prefer to reuse existing flows
             ConnectPoint sinkToBeProcessed = connectPoints.stream()
-                    .filter(connectPoint -> isSink(mcastIp, connectPoint))
-                    .findFirst().orElse(null);
+                    .filter(connectPoint -> {
+                        if (!isSinkForGroup(mcastIp, connectPoint, source)) {
+                            return false;
+                        }
+                        if (!isSinkReachable(mcastIp, connectPoint, source)) {
+                            return false;
+                        }
+                        ConnectPoint other = connectPoints.stream()
+                                .filter(remaining -> !remaining.equals(connectPoint))
+                                .findFirst().orElse(null);
+                        // We are already serving the sink
+                        return !isSinkForSource(mcastIp, other, source);
+                    }).findFirst().orElse(null);
+
             if (sinkToBeProcessed != null) {
                 sinksToBeProcessed.add(sinkToBeProcessed);
                 return;
             }
             // Otherwise we prefer to reuse existing egresses
-            Set<DeviceId> egresses = getDevice(mcastIp, EGRESS);
+            Set<DeviceId> egresses = getDevice(mcastIp, EGRESS, source);
             sinkToBeProcessed = connectPoints.stream()
-                    .filter(connectPoint -> egresses.contains(connectPoint.deviceId()))
-                    .findFirst().orElse(null);
+                    .filter(connectPoint -> {
+                        if (!egresses.contains(connectPoint.deviceId())) {
+                            return false;
+                        }
+                        if (!isSinkReachable(mcastIp, connectPoint, source)) {
+                            return false;
+                        }
+                        ConnectPoint other = connectPoints.stream()
+                                .filter(remaining -> !remaining.equals(connectPoint))
+                                .findFirst().orElse(null);
+                        return !isSinkForSource(mcastIp, other, source);
+                    }).findFirst().orElse(null);
             if (sinkToBeProcessed != null) {
                 sinksToBeProcessed.add(sinkToBeProcessed);
                 return;
@@ -1041,11 +1069,21 @@
                 sinksToBeProcessed.add(sinkToBeProcessed);
                 return;
             }
-            // Finally, we randomly pick a new location
-            sinksToBeProcessed.add(connectPoints.stream()
-                                           .findFirst().orElse(null));
+            // Finally, we randomly pick a new location if it is reachable
+            sinkToBeProcessed = connectPoints.stream()
+                    .filter(connectPoint -> {
+                        if (!isSinkReachable(mcastIp, connectPoint, source)) {
+                            return false;
+                        }
+                        ConnectPoint other = connectPoints.stream()
+                                .filter(remaining -> !remaining.equals(connectPoint))
+                                .findFirst().orElse(null);
+                        return !isSinkForSource(mcastIp, other, source);
+                    }).findFirst().orElse(null);
+            if (sinkToBeProcessed != null) {
+                sinksToBeProcessed.add(sinkToBeProcessed);
+            }
         }));
-        // We have done, return the set
         return sinksToBeProcessed;
     }
 
@@ -1053,21 +1091,34 @@
      * Utility method to remove all the ingress transit ports.
      *
      * @param mcastIp the group ip
-     * @param ingressDevice the ingress device for this group
-     * @param source the source connect point
+     * @param ingressDevices the ingress devices
+     * @param sources the source connect points
      */
-    private void removeIngressTransitPorts(IpAddress mcastIp, DeviceId ingressDevice,
-                                           ConnectPoint source) {
-        Set<PortNumber> ingressTransitPorts = ingressTransitPort(mcastIp);
-        ingressTransitPorts.forEach(ingressTransitPort -> {
-            if (ingressTransitPort != null) {
-                boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
-                                                      mcastIp, mcastUtils.assignedVlan(source));
-                if (isLast) {
-                    mcastRoleStore.remove(new McastStoreKey(mcastIp, ingressDevice));
-                }
+    private void removeIngressTransitPorts(IpAddress mcastIp, Set<DeviceId> ingressDevices,
+                                           Set<ConnectPoint> sources) {
+        Map<ConnectPoint, Set<PortNumber>> ingressTransitPorts = Maps.newHashMap();
+        sources.forEach(source -> {
+            DeviceId ingressDevice = ingressDevices.stream()
+                    .filter(deviceId -> deviceId.equals(source.deviceId()))
+                    .findFirst().orElse(null);
+            if (ingressDevice == null) {
+                log.warn("Skip removeIngressTransitPorts - " +
+                                 "Missing ingress for source {} and group {}",
+                         source, mcastIp);
+                return;
             }
+            ingressTransitPorts.put(source, ingressTransitPort(mcastIp, ingressDevice, source));
         });
+        ingressTransitPorts.forEach((source, ports) -> ports.forEach(ingressTransitPort -> {
+            DeviceId ingressDevice = ingressDevices.stream()
+                    .filter(deviceId -> deviceId.equals(source.deviceId()))
+                    .findFirst().orElse(null);
+            boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
+                                                  mcastIp, mcastUtils.assignedVlan(source));
+            if (isLast) {
+                mcastRoleStore.remove(new McastRoleStoreKey(mcastIp, ingressDevice, source));
+            }
+        }));
     }
 
     /**
@@ -1081,7 +1132,7 @@
      */
     private void addPortToDevice(DeviceId deviceId, PortNumber port,
                                  IpAddress mcastIp, VlanId assignedVlan) {
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId);
+        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId, assignedVlan);
         ImmutableSet.Builder<PortNumber> portBuilder = ImmutableSet.builder();
         NextObjective newNextObj;
         if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
@@ -1118,8 +1169,7 @@
         ObjectiveContext context = new DefaultObjectiveContext(
                 (objective) -> log.debug("Successfully add {} on {}/{}, vlan {}",
                         mcastIp, deviceId, port.toLong(), assignedVlan),
-                (objective, error) ->
-                        log.warn("Failed to add {} on {}/{}, vlan {}: {}",
+                (objective, error) -> log.warn("Failed to add {} on {}/{}, vlan {}: {}",
                                 mcastIp, deviceId, port.toLong(), assignedVlan, error));
         ForwardingObjective fwdObj = mcastUtils.fwdObjBuilder(mcastIp, assignedVlan,
                                                               newNextObj.id()).add(context);
@@ -1141,37 +1191,32 @@
     private boolean removePortFromDevice(DeviceId deviceId, PortNumber port,
                                          IpAddress mcastIp, VlanId assignedVlan) {
         McastStoreKey mcastStoreKey =
-                new McastStoreKey(mcastIp, deviceId);
+                new McastStoreKey(mcastIp, deviceId, assignedVlan);
         // This device is not serving this multicast group
         if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
-            log.warn("{} is not serving {} on port {}. Abort.", deviceId, mcastIp, port);
-            return false;
+            return true;
         }
         NextObjective nextObj = mcastNextObjStore.get(mcastStoreKey).value();
-
         Set<PortNumber> existingPorts = mcastUtils.getPorts(nextObj.next());
         // This port does not serve this multicast group
         if (!existingPorts.contains(port)) {
-            log.warn("{} is not serving {} on port {}. Abort.", deviceId, mcastIp, port);
-            return false;
+            if (!existingPorts.isEmpty()) {
+                log.warn("{} is not serving {} on port {}. Abort.", deviceId, mcastIp, port);
+                return false;
+            }
+            return true;
         }
         // Copy and modify the ImmutableSet
         existingPorts = Sets.newHashSet(existingPorts);
         existingPorts.remove(port);
-
         NextObjective newNextObj;
         ObjectiveContext context;
         ForwardingObjective fwdObj;
         if (existingPorts.isEmpty()) {
-            // If this is the last sink, remove flows and last bucket
-            // NOTE: Rely on GroupStore garbage collection rather than explicitly
-            //       remove L3MG since there might be other flows/groups refer to
-            //       the same L2IG
             context = new DefaultObjectiveContext(
                     (objective) -> log.debug("Successfully remove {} on {}/{}, vlan {}",
                             mcastIp, deviceId, port.toLong(), assignedVlan),
-                    (objective, error) ->
-                            log.warn("Failed to remove {} on {}/{}, vlan {}: {}",
+                    (objective, error) -> log.warn("Failed to remove {} on {}/{}, vlan {}: {}",
                                     mcastIp, deviceId, port.toLong(), assignedVlan, error));
             fwdObj = mcastUtils.fwdObjBuilder(mcastIp, assignedVlan, nextObj.id()).remove(context);
             mcastNextObjStore.remove(mcastStoreKey);
@@ -1180,8 +1225,7 @@
             context = new DefaultObjectiveContext(
                     (objective) -> log.debug("Successfully update {} on {}/{}, vlan {}",
                             mcastIp, deviceId, port.toLong(), assignedVlan),
-                    (objective, error) ->
-                            log.warn("Failed to update {} on {}/{}, vlan {}: {}",
+                    (objective, error) -> log.warn("Failed to update {} on {}/{}, vlan {}: {}",
                                     mcastIp, deviceId, port.toLong(), assignedVlan, error));
             // Here we store the next objective with the remaining port
             newNextObj = mcastUtils.nextObjBuilder(mcastIp, assignedVlan,
@@ -1206,36 +1250,28 @@
      */
     private void removeGroupFromDevice(DeviceId deviceId, IpAddress mcastIp,
                                        VlanId assignedVlan) {
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId);
+        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, deviceId, assignedVlan);
         // This device is not serving this multicast group
         if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
             log.warn("{} is not serving {}. Abort.", deviceId, mcastIp);
             return;
         }
         NextObjective nextObj = mcastNextObjStore.get(mcastStoreKey).value();
-        // NOTE: Rely on GroupStore garbage collection rather than explicitly
-        //       remove L3MG since there might be other flows/groups refer to
-        //       the same L2IG
         ObjectiveContext context = new DefaultObjectiveContext(
                 (objective) -> log.debug("Successfully remove {} on {}, vlan {}",
                         mcastIp, deviceId, assignedVlan),
-                (objective, error) ->
-                        log.warn("Failed to remove {} on {}, vlan {}: {}",
+                (objective, error) -> log.warn("Failed to remove {} on {}, vlan {}: {}",
                                 mcastIp, deviceId, assignedVlan, error));
         ForwardingObjective fwdObj = mcastUtils.fwdObjBuilder(mcastIp, assignedVlan, nextObj.id()).remove(context);
         srManager.flowObjectiveService.forward(deviceId, fwdObj);
         mcastNextObjStore.remove(mcastStoreKey);
-        mcastRoleStore.remove(mcastStoreKey);
     }
 
     private void installPath(IpAddress mcastIp, ConnectPoint source, Path mcastPath) {
-        // Get Links
         List<Link> links = mcastPath.links();
-
         // Setup new ingress mcast role
-        mcastRoleStore.put(new McastStoreKey(mcastIp, links.get(0).src().deviceId()),
+        mcastRoleStore.put(new McastRoleStoreKey(mcastIp, links.get(0).src().deviceId(), source),
                            INGRESS);
-
         // For each link, modify the next on the source device adding the src port
         // and a new filter objective on the destination port
         links.forEach(link -> {
@@ -1244,11 +1280,10 @@
             mcastUtils.addFilterToDevice(link.dst().deviceId(), link.dst().port(),
                               mcastUtils.assignedVlan(null), mcastIp, null);
         });
-
         // Setup mcast role for the transit
         links.stream()
                 .filter(link -> !link.src().deviceId().equals(source.deviceId()))
-                .forEach(link -> mcastRoleStore.put(new McastStoreKey(mcastIp, link.src().deviceId()),
+                .forEach(link -> mcastRoleStore.put(new McastRoleStoreKey(mcastIp, link.src().deviceId(), source),
                                                     TRANSIT));
     }
 
@@ -1262,10 +1297,8 @@
      */
     private Set<Link> exploreMcastTree(Set<DeviceId> egresses,
                                        Map<DeviceId, List<Path>> availablePaths) {
-        // Length of the shortest path
         int minLength = Integer.MAX_VALUE;
         int length;
-        // Current paths
         List<Path> currentPaths;
         // Verify the source can still reach all the egresses
         for (DeviceId egress : egresses) {
@@ -1275,8 +1308,7 @@
             if (currentPaths.isEmpty()) {
                 continue;
             }
-            // Get the length of the first one available,
-            // update the min length
+            // Get the length of the first one available, update the min length
             length = currentPaths.get(0).links().size();
             if (length < minLength) {
                 minLength = length;
@@ -1286,9 +1318,7 @@
         if (minLength == Integer.MAX_VALUE) {
             return Collections.emptySet();
         }
-        // Iterate looking for shared links
         int index = 0;
-        // Define the sets for the intersection
         Set<Link> sharedLinks = Sets.newHashSet();
         Set<Link> currentSharedLinks;
         Set<Link> currentLinks;
@@ -1296,11 +1326,7 @@
         // Let's find out the shared links
         while (index < minLength) {
             // Initialize the intersection with the paths related to the first egress
-            currentPaths = availablePaths.get(
-                    egresses.stream()
-                            .findFirst()
-                            .orElse(null)
-            );
+            currentPaths = availablePaths.get(egresses.stream().findFirst().orElse(null));
             currentSharedLinks = Sets.newHashSet();
             // Iterate over the paths and take the "index" links
             for (Path path : currentPaths) {
@@ -1326,9 +1352,8 @@
             sharedLinks.addAll(currentSharedLinks);
             index++;
         }
-        // If the shared links is empty and there are egress
-        // let's retry another time with less sinks, we can
-        // still build optimal subtrees
+        // If no shared links were found and more than one egress is left, retry with fewer sinks:
+        // we can still build optimal subtrees
         if (sharedLinks.isEmpty() && egresses.size() > 1 && egressToRemove != null) {
             egresses.remove(egressToRemove);
             sharedLinks = exploreMcastTree(egresses, availablePaths);
@@ -1346,12 +1371,9 @@
     private Map<ConnectPoint, List<Path>> computeSinkMcastTree(DeviceId source,
                                                                Set<ConnectPoint> sinks) {
         // Get the egress devices, remove source from the egress if present
-        Set<DeviceId> egresses = sinks.stream()
-                .map(ConnectPoint::deviceId)
-                .filter(deviceId -> !deviceId.equals(source))
-                .collect(Collectors.toSet());
+        Set<DeviceId> egresses = sinks.stream().map(ConnectPoint::deviceId)
+                .filter(deviceId -> !deviceId.equals(source)).collect(Collectors.toSet());
         Map<DeviceId, List<Path>> mcastTree = computeMcastTree(source, egresses);
-        // Build final tree and return it as it is
         final Map<ConnectPoint, List<Path>> finalTree = Maps.newHashMap();
         // We need to put back the source if it was originally present
         sinks.forEach(sink -> {
@@ -1372,14 +1394,12 @@
                                                        Set<DeviceId> egresses) {
         // Pre-compute all the paths
         Map<DeviceId, List<Path>> availablePaths = Maps.newHashMap();
-        // No links to enforce
         egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
                                                                        Collections.emptySet())));
         // Explore the topology looking for shared links amongst the egresses
         Set<Link> linksToEnforce = exploreMcastTree(Sets.newHashSet(egresses), availablePaths);
-        // Remove all the paths from the previous computation
-        availablePaths.clear();
         // Build the final paths enforcing the shared links between egress devices
+        availablePaths.clear();
         egresses.forEach(egress -> availablePaths.put(egress, getPaths(source, egress,
                                                                        linksToEnforce)));
         return availablePaths;
@@ -1393,16 +1413,9 @@
      * @return list of paths from src to dst
      */
     private List<Path> getPaths(DeviceId src, DeviceId dst, Set<Link> linksToEnforce) {
-        // Takes a snapshot of the topology
         final Topology currentTopology = topologyService.currentTopology();
-        // Build a specific link weigher for this path computation
         final LinkWeigher linkWeigher = new SRLinkWeigher(srManager, src, linksToEnforce);
-        // We will use our custom link weigher for our path
-        // computations and build the list of valid paths
-        List<Path> allPaths = Lists.newArrayList(
-                topologyService.getPaths(currentTopology, src, dst, linkWeigher)
-        );
-        // If there are no valid paths, just exit
+        List<Path> allPaths = Lists.newArrayList(topologyService.getPaths(currentTopology, src, dst, linkWeigher));
         log.debug("{} path(s) found from {} to {}", allPaths.size(), src, dst);
         return allPaths;
     }
@@ -1418,55 +1431,44 @@
      * @param allPaths paths list
      * @return an optional path from src to dst
      */
-    private Optional<Path> getPath(DeviceId src, DeviceId dst,
-                                   IpAddress mcastIp, List<Path> allPaths) {
-        // Firstly we get all the valid paths, if the supplied are null
+    private Optional<Path> getPath(DeviceId src, DeviceId dst, IpAddress mcastIp,
+                                   List<Path> allPaths, ConnectPoint source) {
         if (allPaths == null) {
             allPaths = getPaths(src, dst, Collections.emptySet());
         }
-
-        // If there are no paths just exit
         if (allPaths.isEmpty()) {
             return Optional.empty();
         }
-
-        // Create a map index of suitablity-to-list of paths. For example
+        // Create a map index of suitability-to-list of paths. For example
         // a path in the list associated to the index 1 shares only the
         // first hop and it is less suitable of a path belonging to the index
         // 2 that shares leaf-spine.
         Map<Integer, List<Path>> eligiblePaths = Maps.newHashMap();
-        // Some init steps
         int nhop;
         McastStoreKey mcastStoreKey;
-        Link hop;
         PortNumber srcPort;
         Set<PortNumber> existingPorts;
         NextObjective nextObj;
-        // Iterate over paths looking for eligible paths
         for (Path path : allPaths) {
-            // Unlikely, it will happen...
             if (!src.equals(path.links().get(0).src().deviceId())) {
                 continue;
             }
             nhop = 0;
             // Iterate over the links
-            while (nhop < path.links().size()) {
-                // Get the link and verify if a next related
-                // to the src device exist in the store
-                hop = path.links().get(nhop);
-                mcastStoreKey = new McastStoreKey(mcastIp, hop.src().deviceId());
-                // It does not exist in the store, exit
+            for (Link hop : path.links()) {
+                VlanId assignedVlan = mcastUtils.assignedVlan(hop.src().deviceId().equals(src) ?
+                                                                      source : null);
+                mcastStoreKey = new McastStoreKey(mcastIp, hop.src().deviceId(), assignedVlan);
+                // It does not exist in the store, go to the next link
                 if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
-                    break;
+                    continue;
                 }
-                // Get the output ports on the next
                 nextObj = mcastNextObjStore.get(mcastStoreKey).value();
                 existingPorts = mcastUtils.getPorts(nextObj.next());
-                // And the src port on the link
                 srcPort = hop.src().port();
-                // the src port is not used as output, exit
+                // the src port is not used as output, go to the next link
                 if (!existingPorts.contains(srcPort)) {
-                    break;
+                    continue;
                 }
                 nhop++;
             }
@@ -1479,29 +1481,38 @@
                 });
             }
         }
-
-        // No suitable paths
         if (eligiblePaths.isEmpty()) {
             log.debug("No eligiblePath(s) found from {} to {}", src, dst);
-            // Otherwise, randomly pick a path
             Collections.shuffle(allPaths);
             return allPaths.stream().findFirst();
         }
-
         // Let's take the best ones
-        Integer bestIndex = eligiblePaths.keySet()
-                .stream()
-                .sorted(Comparator.reverseOrder())
-                .findFirst().orElse(null);
+        Integer bestIndex = eligiblePaths.keySet().stream()
+                .sorted(Comparator.reverseOrder()).findFirst().orElse(null);
         List<Path> bestPaths = eligiblePaths.get(bestIndex);
         log.debug("{} eligiblePath(s) found from {} to {}",
                   bestPaths.size(), src, dst);
-        // randomly pick a path on the highest index
         Collections.shuffle(bestPaths);
         return bestPaths.stream().findFirst();
     }
 
     /**
+     * Gets device(s) of given role and of given source in given multicast tree.
+     *
+     * @param mcastIp multicast IP
+     * @param role multicast role
+     * @param source source connect point
+     * @return set of device ID or empty set if not found
+     */
+    private Set<DeviceId> getDevice(IpAddress mcastIp, McastRole role, ConnectPoint source) {
+        return mcastRoleStore.entrySet().stream()
+                .filter(roleEntry -> roleEntry.getKey().mcastIp().equals(mcastIp))
+                .filter(roleEntry -> roleEntry.getKey().source().equals(source))
+                .filter(roleEntry -> roleEntry.getValue().value() == role)
+                .map(roleEntry -> roleEntry.getKey().deviceId()).collect(Collectors.toSet());
+    }
+
+    /**
      * Gets device(s) of given role in given multicast group.
      *
      * @param mcastIp multicast IP
@@ -1512,8 +1523,34 @@
         return mcastRoleStore.entrySet().stream()
                 .filter(entry -> entry.getKey().mcastIp().equals(mcastIp) &&
                         entry.getValue().value() == role)
-                .map(Entry::getKey).map(McastStoreKey::deviceId)
-                .collect(Collectors.toSet());
+                .map(Entry::getKey).map(McastRoleStoreKey::deviceId).collect(Collectors.toSet());
+    }
+
+    /**
+     * Gets source(s) of given role, given device in given multicast group.
+     *
+     * @param mcastIp multicast IP
+     * @param deviceId device id
+     * @param role multicast role
+     * @return set of source connect points or empty set if not found
+     */
+    private Set<ConnectPoint> getSources(IpAddress mcastIp, DeviceId deviceId, McastRole role) {
+        return mcastRoleStore.entrySet().stream().filter(roleEntry -> roleEntry.getKey().mcastIp().equals(mcastIp))
+                .filter(roleEntry -> roleEntry.getKey().deviceId().equals(deviceId))
+                .filter(roleEntry -> roleEntry.getValue().value() == role)
+                .map(roleEntry -> roleEntry.getKey().source()).collect(Collectors.toSet());
+    }
+
+    /**
+     * Gets source(s) of given multicast group.
+     *
+     * @param mcastIp multicast IP
+     * @return set of source connect points or empty set if not found
+     */
+    private Set<ConnectPoint> getSources(IpAddress mcastIp) {
+        return mcastRoleStore.entrySet().stream().map(Entry::getKey)
+                .filter(roleKey -> roleKey.mcastIp().equals(mcastIp))
+                .map(McastRoleStoreKey::source).collect(Collectors.toSet());
     }
 
     /**
@@ -1527,9 +1564,8 @@
         PortNumber port = link.src().port();
         return mcastNextObjStore.entrySet().stream()
                 .filter(entry -> entry.getKey().deviceId().equals(deviceId) &&
-                        mcastUtils.getPorts(entry.getValue().value().next()).contains(port))
-                .map(Entry::getKey).map(McastStoreKey::mcastIp)
-                .collect(Collectors.toSet());
+                    mcastUtils.getPorts(entry.getValue().value().next()).contains(port))
+                .map(Entry::getKey).map(McastStoreKey::mcastIp).collect(Collectors.toSet());
     }
 
     /**
@@ -1549,15 +1585,16 @@
      * Gets the spine-facing port on ingress device of given multicast group.
      *
      * @param mcastIp multicast IP
+     * @param ingressDevice the ingress device
+     * @param source the source connect point
      * @return spine-facing port on ingress device
      */
-    private Set<PortNumber> ingressTransitPort(IpAddress mcastIp) {
-        DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
-                .stream().findAny().orElse(null);
+    private Set<PortNumber> ingressTransitPort(IpAddress mcastIp, DeviceId ingressDevice,
+                                               ConnectPoint source) {
         ImmutableSet.Builder<PortNumber> portBuilder = ImmutableSet.builder();
         if (ingressDevice != null) {
-            NextObjective nextObj = mcastNextObjStore
-                    .get(new McastStoreKey(mcastIp, ingressDevice)).value();
+            NextObjective nextObj = mcastNextObjStore.get(new McastStoreKey(mcastIp, ingressDevice,
+                                                            mcastUtils.assignedVlan(source))).value();
             Set<PortNumber> ports = mcastUtils.getPorts(nextObj.next());
             // Let's find out all the ingress-transit ports
             for (PortNumber port : ports) {
@@ -1573,58 +1610,64 @@
     }
 
     /**
-     * Verify if the given device has sinks
-     * for the multicast group.
-     *
-     * @param deviceId device Id
-     * @param mcastIp multicast IP
-     * @return true if the device has sink for the group.
-     * False otherwise.
-     */
-    private boolean hasSinks(DeviceId deviceId, IpAddress mcastIp) {
-        if (deviceId != null) {
-            // Get the nextobjective
-            Versioned<NextObjective> versionedNextObj = mcastNextObjStore.get(
-                    new McastStoreKey(mcastIp, deviceId)
-            );
-            // If it exists
-            if (versionedNextObj != null) {
-                NextObjective nextObj = versionedNextObj.value();
-                // Retrieves all the output ports
-                Set<PortNumber> ports = mcastUtils.getPorts(nextObj.next());
-                // Tries to find at least one port that is not spine-facing
-                for (PortNumber port : ports) {
-                    // Spine-facing port should have no subnet and no xconnect
-                    if (srManager.deviceConfiguration() != null &&
-                            (!srManager.deviceConfiguration().getPortSubnets(deviceId, port).isEmpty() ||
-                            srManager.xConnectHandler.hasXConnect(new ConnectPoint(deviceId, port)))) {
-                        return true;
-                    }
-                }
-            }
-        }
-        return false;
-    }
-
-    /**
      * Verify if a given connect point is sink for this group.
      *
      * @param mcastIp group address
      * @param connectPoint connect point to be verified
+     * @param source source connect point
      * @return true if the connect point is sink of the group
      */
-    private boolean isSink(IpAddress mcastIp, ConnectPoint connectPoint) {
-        // Let's check if we are already serving that location
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, connectPoint.deviceId());
+    private boolean isSinkForGroup(IpAddress mcastIp, ConnectPoint connectPoint,
+                                   ConnectPoint source) {
+        VlanId assignedVlan = mcastUtils.assignedVlan(connectPoint.deviceId().equals(source.deviceId()) ?
+                                                              source : null);
+        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, connectPoint.deviceId(), assignedVlan);
         if (!mcastNextObjStore.containsKey(mcastStoreKey)) {
             return false;
         }
-        // Get next and check with the port
         NextObjective mcastNext = mcastNextObjStore.get(mcastStoreKey).value();
         return mcastUtils.getPorts(mcastNext.next()).contains(connectPoint.port());
     }
 
     /**
+     * Verify if a given connect point is sink for this group and for this source.
+     *
+     * @param mcastIp group address
+     * @param connectPoint connect point to be verified
+     * @param source source connect point
+     * @return true if the connect point is sink of the group for the given source
+     */
+    private boolean isSinkForSource(IpAddress mcastIp, ConnectPoint connectPoint,
+                                    ConnectPoint source) {
+        // The port has to be an output of the next objective stored for the sink device
+        boolean servedByNext = isSinkForGroup(mcastIp, connectPoint, source);
+        // Sinks on the source device are matched against INGRESS devices, the others against EGRESS
+        DeviceId sinkDevice = connectPoint.deviceId();
+        McastRole expectedRole;
+        if (sinkDevice.equals(source.deviceId())) {
+            expectedRole = INGRESS;
+        } else {
+            expectedRole = EGRESS;
+        }
+        Set<DeviceId> roleDevices = getDevice(mcastIp, expectedRole, source);
+        return servedByNext && roleDevices.contains(sinkDevice);
+    }
+
+    /**
+     * Verify if a sink is reachable from this source.
+     *
+     * @param mcastIp group address
+     * @param sink connect point to be verified
+     * @param source source connect point
+     * @return true if the connect point is reachable from the source
+     */
+    private boolean isSinkReachable(IpAddress mcastIp, ConnectPoint sink, ConnectPoint source) {
+        // A sink attached to the source device needs no path, otherwise one has to be available
+        boolean sameDevice = sink.deviceId().equals(source.deviceId());
+        return sameDevice || getPath(source.deviceId(), sink.deviceId(), mcastIp, null, source).isPresent();
+    }
+
+    /**
      * Updates filtering objective for given device and port.
      * It is called in general when the mcast config has been
      * changed.
@@ -1639,26 +1682,24 @@
         lastMcastChange = Instant.now();
         mcastLock();
         try {
-            // Iterates over the route and updates properly the filtering objective
-            // on the source device.
+            // Iterates over the route and updates properly the filtering objective on the source device.
             srManager.multicastRouteService.getRoutes().forEach(mcastRoute -> {
                 log.debug("Update filter for {}", mcastRoute.group());
-                // Verify leadership on the operation
                 if (!mcastUtils.isLeader(mcastRoute.group())) {
                     log.debug("Skip {} due to lack of leadership", mcastRoute.group());
                     return;
                 }
-                // FIXME To be addressed with multiple sources support
-                ConnectPoint source = srManager.multicastRouteService.sources(mcastRoute)
-                        .stream()
-                        .findFirst().orElse(null);
-                if (source.deviceId().equals(deviceId) && source.port().equals(portNum)) {
-                    if (install) {
-                        mcastUtils.addFilterToDevice(deviceId, portNum, vlanId, mcastRoute.group(), INGRESS);
-                    } else {
-                        mcastUtils.removeFilterToDevice(deviceId, portNum, vlanId, mcastRoute.group(), null);
+                // Get the sources and for each one update properly the filtering objectives
+                Set<ConnectPoint> sources = srManager.multicastRouteService.sources(mcastRoute);
+                sources.forEach(source -> {
+                    if (source.deviceId().equals(deviceId) && source.port().equals(portNum)) {
+                        if (install) {
+                            mcastUtils.addFilterToDevice(deviceId, portNum, vlanId, mcastRoute.group(), INGRESS);
+                        } else {
+                            mcastUtils.removeFilterToDevice(deviceId, portNum, vlanId, mcastRoute.group(), null);
+                        }
                     }
-                }
+                });
             });
         } finally {
             mcastUnlock();
@@ -1671,111 +1712,105 @@
      * Verification consists in creating new nexts with VERIFY operation. Actually,
      * the operation is totally delegated to the driver.
      */
-     private final class McastBucketCorrector implements Runnable {
+    private final class McastBucketCorrector implements Runnable {
 
         @Override
         public void run() {
-            // Verify if the Mcast has been stable for MCAST_STABLITY_THRESHOLD
             if (!isMcastStable()) {
                 return;
             }
-            // Acquires lock
             mcastLock();
             try {
                 // Iterates over the routes and verify the related next objectives
                 srManager.multicastRouteService.getRoutes()
-                    .stream()
-                    .map(McastRoute::group)
+                    .stream().map(McastRoute::group)
                     .forEach(mcastIp -> {
-                        log.trace("Running mcast buckets corrector for mcast group: {}",
-                                  mcastIp);
-
-                        // For each group we get current information in the store
-                        // and issue a check of the next objectives in place
-                        DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
-                                .stream().findAny().orElse(null);
-                        Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
-                        Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
-                        // Get source and sinks from Mcast Route Service and warn about errors
-                        ConnectPoint source = mcastUtils.getSource(mcastIp);
+                        log.trace("Running mcast buckets corrector for mcast group: {}", mcastIp);
+                        // Verify leadership on the operation
+                        if (!mcastUtils.isLeader(mcastIp)) {
+                            log.trace("Skip {} due to lack of leadership", mcastIp);
+                            return;
+                        }
+                        // Get sources and sinks from Mcast Route Service and warn about errors
+                        Set<ConnectPoint> sources = mcastUtils.getSources(mcastIp);
                         Set<ConnectPoint> sinks = mcastUtils.getSinks(mcastIp).values().stream()
-                                .flatMap(Collection::stream)
-                                .collect(Collectors.toSet());
-
-                        // Do not proceed if ingress device or source of this group are missing
-                        if (ingressDevice == null || source == null) {
+                                .flatMap(Collection::stream).collect(Collectors.toSet());
+                        // Do not proceed if sources of this group are missing
+                        if (sources.isEmpty()) {
                             if (!sinks.isEmpty()) {
                                 log.warn("Unable to run buckets corrector. " +
-                                                 "Missing ingress {} or source {} for group {}",
-                                         ingressDevice, source, mcastIp);
+                                                 "Missing source {} for group {}", sources, mcastIp);
                             }
                             return;
                         }
-
-                        // Continue only when this instance is the leader of the group
-                        if (!mcastUtils.isLeader(mcastIp)) {
-                            log.trace("Unable to run buckets corrector. " +
-                                             "Skip {} due to lack of leadership", mcastIp);
-                            return;
-                        }
-
-                        // Create the set of the devices to be processed
-                        ImmutableSet.Builder<DeviceId> devicesBuilder = ImmutableSet.builder();
-                        devicesBuilder.add(ingressDevice);
-                        if (!transitDevices.isEmpty()) {
-                            devicesBuilder.addAll(transitDevices);
-                        }
-                        if (!egressDevices.isEmpty()) {
-                            devicesBuilder.addAll(egressDevices);
-                        }
-                        Set<DeviceId> devicesToProcess = devicesBuilder.build();
-
-                        // Iterate over the devices
-                        devicesToProcess.forEach(deviceId -> {
-                            McastStoreKey currentKey = new McastStoreKey(mcastIp, deviceId);
-                            // If next exists in our store verify related next objective
-                            if (mcastNextObjStore.containsKey(currentKey)) {
-                                NextObjective currentNext = mcastNextObjStore.get(currentKey).value();
-                                // Get current ports
-                                Set<PortNumber> currentPorts = mcastUtils.getPorts(currentNext.next());
-                                // Rebuild the next objective
-                                currentNext = mcastUtils.nextObjBuilder(
-                                        mcastIp,
-                                        mcastUtils.assignedVlan(deviceId.equals(source.deviceId()) ?
-                                                                        source : null),
-                                        currentPorts,
-                                        currentNext.id()
-                                ).verify();
-                                // Send to the flowobjective service
-                                srManager.flowObjectiveService.next(deviceId, currentNext);
-                            } else {
-                                log.warn("Unable to run buckets corrector. " +
-                                                 "Missing next for {} and group {}",
-                                         deviceId, mcastIp);
+                        sources.forEach(source -> {
+                            // For each group we get current information in the store
+                            // and issue a check of the next objectives in place
+                            Set<DeviceId> ingressDevices = getDevice(mcastIp, INGRESS, source);
+                            Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT, source);
+                            Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS, source);
+                            // Do not proceed if ingress devices are missing
+                            if (ingressDevices.isEmpty()) {
+                                if (!sinks.isEmpty()) {
+                                    log.warn("Unable to run buckets corrector. " +
+                                                 "Missing ingress {} for source {} and for group {}",
+                                             ingressDevices, source, mcastIp);
+                                }
+                                return;
                             }
+                            // Create the set of the devices to be processed
+                            ImmutableSet.Builder<DeviceId> devicesBuilder = ImmutableSet.builder();
+                            if (!ingressDevices.isEmpty()) {
+                                devicesBuilder.addAll(ingressDevices);
+                            }
+                            if (!transitDevices.isEmpty()) {
+                                devicesBuilder.addAll(transitDevices);
+                            }
+                            if (!egressDevices.isEmpty()) {
+                                devicesBuilder.addAll(egressDevices);
+                            }
+                            Set<DeviceId> devicesToProcess = devicesBuilder.build();
+                            devicesToProcess.forEach(deviceId -> {
+                                VlanId assignedVlan = mcastUtils.assignedVlan(deviceId.equals(source.deviceId()) ?
+                                                                                      source : null);
+                                McastStoreKey currentKey = new McastStoreKey(mcastIp, deviceId, assignedVlan);
+                                if (mcastNextObjStore.containsKey(currentKey)) {
+                                    NextObjective currentNext = mcastNextObjStore.get(currentKey).value();
+                                    // Rebuild the next objective using assigned vlan
+                                    currentNext = mcastUtils.nextObjBuilder(mcastIp, assignedVlan,
+                                                mcastUtils.getPorts(currentNext.next()), currentNext.id()).verify();
+                                    // Send to the flowobjective service
+                                    srManager.flowObjectiveService.next(deviceId, currentNext);
+                                } else {
+                                    log.warn("Unable to run buckets corrector. " +
+                                             "Missing next for {}, for source {} and for group {}",
+                                             deviceId, source, mcastIp);
+                                }
+                            });
                         });
-
                     });
             } finally {
-                // Finally, it releases the lock
                 mcastUnlock();
             }
 
         }
     }
 
+    /**
+     * Returns the associated next ids to the mcast groups or to the single
+     * group if mcastIp is present.
+     *
+     * @param mcastIp the group ip
+     * @return the mapping mcastIp-device to next id
+     */
     public Map<McastStoreKey, Integer> getMcastNextIds(IpAddress mcastIp) {
-        // If mcast ip is present
         if (mcastIp != null) {
             return mcastNextObjStore.entrySet().stream()
                     .filter(mcastEntry -> mcastIp.equals(mcastEntry.getKey().mcastIp()))
-                    .collect(Collectors.toMap(Entry::getKey,
-                                              entry -> entry.getValue().value().id()));
+                    .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().value().id()));
         }
-        // Otherwise take all the groups
         return mcastNextObjStore.entrySet().stream()
-                .collect(Collectors.toMap(Entry::getKey,
-                                          entry -> entry.getValue().value().id()));
+                .collect(Collectors.toMap(Entry::getKey, entry -> entry.getValue().value().id()));
     }
 
     /**
@@ -1789,20 +1824,46 @@
      */
     @Deprecated
     public Map<McastStoreKey, McastRole> getMcastRoles(IpAddress mcastIp) {
-        // If mcast ip is present
-        if (mcastIp != null) {
-            return mcastRoleStore.entrySet().stream()
-                    .filter(mcastEntry -> mcastIp.equals(mcastEntry.getKey().mcastIp()))
-                    .collect(Collectors.toMap(Entry::getKey,
-                                              entry -> entry.getValue().value()));
-        }
-        // Otherwise take all the groups
-        return mcastRoleStore.entrySet().stream()
-                .collect(Collectors.toMap(Entry::getKey,
-                                          entry -> entry.getValue().value()));
+        // The legacy key carries no source, so it is built with the two args
+        // constructor (the three args one rejects a null vlan id). Roles of
+        // different sources can collapse onto the same key: keep the first.
+        return mcastRoleStore.entrySet().stream()
+                .filter(mcastEntry -> mcastIp == null || mcastIp.equals(mcastEntry.getKey().mcastIp()))
+                .collect(Collectors.toMap(
+                        entry -> new McastStoreKey(entry.getKey().mcastIp(), entry.getKey().deviceId()),
+                        entry -> entry.getValue().value(),
+                        (first, second) -> first));
     }
 
     /**
+     * Returns the roles stored for the mcast groups, keyed by group ip,
+     * device and source. Both arguments are optional filters which are
+     * applied independently of each other, so a source can be selected
+     * even when no group ip is given.
+     *
+     * @param mcastIp the group ip, or null to include every group
+     * @param sourcecp the source connect point, or null to include every source
+     * @return the mapping mcastIp-device-source to mcast role
+     */
+    public Map<McastRoleStoreKey, McastRole> getMcastRoles(IpAddress mcastIp,
+                                                       ConnectPoint sourcecp) {
+        // A null filter matches all the entries; checking both in a single
+        // pass lets the source be honored even without a group ip and avoids
+        // collecting an intermediate map.
+        return mcastRoleStore.entrySet().stream()
+                .filter(mcastEntry -> mcastIp == null ||
+                        mcastIp.equals(mcastEntry.getKey().mcastIp()))
+                .filter(mcastEntry -> sourcecp == null ||
+                        sourcecp.equals(mcastEntry.getKey().source()))
+                .collect(Collectors.toMap(
+                        entry -> new McastRoleStoreKey(entry.getKey().mcastIp(),
+                                                       entry.getKey().deviceId(),
+                                                       entry.getKey().source()),
+                        entry -> entry.getValue().value()));
+    }
+
+
+    /**
      * Returns the associated paths to the mcast group.
      *
      * @param mcastIp the group ip
@@ -1813,17 +1874,11 @@
     @Deprecated
     public Map<ConnectPoint, List<ConnectPoint>> getMcastPaths(IpAddress mcastIp) {
         Map<ConnectPoint, List<ConnectPoint>> mcastPaths = Maps.newHashMap();
-        // Get the source
         ConnectPoint source = mcastUtils.getSource(mcastIp);
-        // Source cannot be null, we don't know the starting point
+        // The walk needs a starting point: without a source the result stays empty
         if (source != null) {
-            // Init steps
-            Set<DeviceId> visited = Sets.newHashSet();
-            List<ConnectPoint> currentPath = Lists.newArrayList(
-                    source
-            );
-            // Build recursively the mcast paths
-            buildMcastPaths(source.deviceId(), visited, mcastPaths, currentPath, mcastIp);
+            buildMcastPaths(source.deviceId(), Sets.newHashSet(), mcastPaths,
+                            Lists.newArrayList(source), mcastIp, source);
         }
         return mcastPaths;
     }
@@ -1838,25 +1893,17 @@
     public Multimap<ConnectPoint, List<ConnectPoint>> getMcastTrees(IpAddress mcastIp,
                                                                     ConnectPoint sourcecp) {
         Multimap<ConnectPoint, List<ConnectPoint>> mcastTrees = HashMultimap.create();
-        // Get the sources
         Set<ConnectPoint> sources = mcastUtils.getSources(mcastIp);
-
-        // If we are providing the source, let's filter out
         if (sourcecp != null) {
             sources = sources.stream()
-                    .filter(source -> source.equals(sourcecp))
-                    .collect(Collectors.toSet());
+                    .filter(source -> source.equals(sourcecp)).collect(Collectors.toSet());
         }
-
-        // Source cannot be null, we don't know the starting point
         if (!sources.isEmpty()) {
             sources.forEach(source -> {
-                // Init steps
                 Map<ConnectPoint, List<ConnectPoint>> mcastPaths = Maps.newHashMap();
                 Set<DeviceId> visited = Sets.newHashSet();
                 List<ConnectPoint> currentPath = Lists.newArrayList(source);
-                // Build recursively the mcast paths
-                buildMcastPaths(source.deviceId(), visited, mcastPaths, currentPath, mcastIp);
+                buildMcastPaths(source.deviceId(), visited, mcastPaths, currentPath, mcastIp, source);
                 mcastPaths.forEach(mcastTrees::put);
             });
         }
@@ -1871,29 +1918,28 @@
      * @param mcastPaths the current mcast paths
      * @param currentPath the current path
      * @param mcastIp the group ip
+     * @param source the source
      */
     private void buildMcastPaths(DeviceId toVisit, Set<DeviceId> visited,
                                  Map<ConnectPoint, List<ConnectPoint>> mcastPaths,
-                                 List<ConnectPoint> currentPath, IpAddress mcastIp) {
-        // If we have visited the node to visit
-        // there is a loop
+                                 List<ConnectPoint> currentPath, IpAddress mcastIp,
+                                 ConnectPoint source) {
+        // If we have visited the node to visit there is a loop
         if (visited.contains(toVisit)) {
             return;
         }
         // Visit next-hop
         visited.add(toVisit);
-        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, toVisit);
+        VlanId assignedVlan = mcastUtils.assignedVlan(toVisit.equals(source.deviceId()) ? source : null);
+        McastStoreKey mcastStoreKey = new McastStoreKey(mcastIp, toVisit, assignedVlan);
         // Looking for next-hops
         if (mcastNextObjStore.containsKey(mcastStoreKey)) {
-            // Build egress connectpoints
+            // Build egress connect points, get ports and build relative cps
             NextObjective nextObjective = mcastNextObjStore.get(mcastStoreKey).value();
-            // Get Ports
             Set<PortNumber> outputPorts = mcastUtils.getPorts(nextObjective.next());
-            // Build relative cps
             ImmutableSet.Builder<ConnectPoint> cpBuilder = ImmutableSet.builder();
             outputPorts.forEach(portNumber -> cpBuilder.add(new ConnectPoint(toVisit, portNumber)));
             Set<ConnectPoint> egressPoints = cpBuilder.build();
-            // Define other variables for the next steps
             Set<Link> egressLinks;
             List<ConnectPoint> newCurrentPath;
             Set<DeviceId> newVisited;
@@ -1905,20 +1951,16 @@
                     // Add the connect points to the path
                     newCurrentPath = Lists.newArrayList(currentPath);
                     newCurrentPath.add(0, egressPoint);
-                    // Save in the map
                     mcastPaths.put(egressPoint, newCurrentPath);
                 } else {
                     newVisited = Sets.newHashSet(visited);
                     // Iterate over the egress links for the next hops
                     for (Link egressLink : egressLinks) {
-                        // Update to visit
                         newToVisit = egressLink.dst().deviceId();
-                        // Add the connect points to the path
                         newCurrentPath = Lists.newArrayList(currentPath);
                         newCurrentPath.add(0, egressPoint);
                         newCurrentPath.add(0, egressLink.dst());
-                        // Go to the next hop
-                        buildMcastPaths(newToVisit, newVisited, mcastPaths, newCurrentPath, mcastIp);
+                        buildMcastPaths(newToVisit, newVisited, mcastPaths, newCurrentPath, mcastIp, source);
                     }
                 }
             }
diff --git a/app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKey.java
similarity index 60%
copy from app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java
copy to app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKey.java
index 6891b77..26f21f2 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKey.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016-present Open Networking Foundation
+ * Copyright 2018-present Open Networking Foundation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,34 +14,49 @@
  * limitations under the License.
  */
 
-package org.onosproject.segmentrouting.storekey;
+package org.onosproject.segmentrouting.mcast;
 
 import org.onlab.packet.IpAddress;
+import org.onosproject.net.ConnectPoint;
 import org.onosproject.net.DeviceId;
+
+import java.util.Objects;
+
 import static com.google.common.base.MoreObjects.toStringHelper;
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
-import java.util.Objects;
 
 /**
- * Key of multicast next objective store.
+ * Key of multicast role store.
  */
-public class McastStoreKey {
+public class McastRoleStoreKey {
+    // Identify role using group address, deviceId and source
     private final IpAddress mcastIp;
     private final DeviceId deviceId;
+    private final ConnectPoint source;
 
     /**
-     * Constructs the key of multicast next objective store.
+     * Constructs the key of multicast role store; every argument is mandatory.
      *
      * @param mcastIp multicast group IP address
      * @param deviceId device ID
+     * @param source source connect point
      */
-    public McastStoreKey(IpAddress mcastIp, DeviceId deviceId) {
+    public McastRoleStoreKey(IpAddress mcastIp, DeviceId deviceId, ConnectPoint source) {
         checkNotNull(mcastIp, "mcastIp cannot be null");
         checkNotNull(deviceId, "deviceId cannot be null");
+        checkNotNull(source, "source cannot be null");
         checkArgument(mcastIp.isMulticast(), "mcastIp must be a multicast address");
         this.mcastIp = mcastIp;
         this.deviceId = deviceId;
+        this.source = source;
+    }
+
+    // Constructor for serialization; every field is left null
+    private McastRoleStoreKey() {
+        this.mcastIp = null;
+        this.deviceId = null;
+        this.source = null;
     }
 
     /**
@@ -62,23 +77,33 @@
         return deviceId;
     }
 
+    /**
+     * Returns the source connect point of this key.
+     *
+     * @return the source connect point, non-null for keys built with the public constructor
+     */
+    public ConnectPoint source() {
+        return source;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
             return true;
         }
-        if (!(o instanceof McastStoreKey)) {
+        if (!(o instanceof McastRoleStoreKey)) {
             return false;
         }
-        McastStoreKey that =
-                (McastStoreKey) o;
-        return (Objects.equals(this.mcastIp, that.mcastIp) &&
-                Objects.equals(this.deviceId, that.deviceId));
+        final McastRoleStoreKey that = (McastRoleStoreKey) o;
+        // Two keys are equal only when group ip, device and source all match
+        return Objects.equals(this.mcastIp, that.mcastIp) &&
+                Objects.equals(this.deviceId, that.deviceId) &&
+                Objects.equals(this.source, that.source);
     }
 
     @Override
     public int hashCode() {
-         return Objects.hash(mcastIp, deviceId);
+        return Objects.hash(this.mcastIp, this.deviceId, this.source);
     }
 
     @Override
@@ -86,6 +111,7 @@
         return toStringHelper(getClass())
                 .add("mcastIp", mcastIp)
                 .add("deviceId", deviceId)
+                .add("source", source)
                 .toString();
     }
 }
diff --git a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKeySerializer.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKeySerializer.java
new file mode 100644
index 0000000..aec7278
--- /dev/null
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastRoleStoreKeySerializer.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018-present Open Networking Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.segmentrouting.mcast;
+
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.onlab.packet.IpAddress;
+import org.onosproject.net.ConnectPoint;
+import org.onosproject.net.DeviceId;
+
+/**
+ * Kryo serializer writing a {@link McastRoleStoreKey} as group ip, device id and source.
+ */
+class McastRoleStoreKeySerializer extends Serializer<McastRoleStoreKey> {
+
+    /**
+     * Creates the serializer; keys are flagged as non-null and immutable.
+     */
+    McastRoleStoreKeySerializer() {
+        // acceptsNull = false, immutable = true
+        super(false, true);
+    }
+
+    @Override
+    public void write(Kryo kryo, Output output, McastRoleStoreKey object) {
+        // Field order must mirror the one expected by read
+        kryo.writeClassAndObject(output, object.mcastIp());
+        output.writeString(object.deviceId().toString());
+        kryo.writeClassAndObject(output, object.source());
+    }
+
+    @Override
+    public McastRoleStoreKey read(Kryo kryo, Input input, Class<McastRoleStoreKey> type) {
+        final IpAddress group = (IpAddress) kryo.readClassAndObject(input);
+        final DeviceId device = DeviceId.deviceId(input.readString());
+        final ConnectPoint origin = (ConnectPoint) kryo.readClassAndObject(input);
+        return new McastRoleStoreKey(group, device, origin);
+    }
+}
diff --git a/app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKey.java
similarity index 61%
rename from app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java
rename to app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKey.java
index 6891b77..aa32797 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/storekey/McastStoreKey.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKey.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2016-present Open Networking Foundation
+ * Copyright 2018-present Open Networking Foundation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,9 +14,10 @@
  * limitations under the License.
  */
 
-package org.onosproject.segmentrouting.storekey;
+package org.onosproject.segmentrouting.mcast;
 
 import org.onlab.packet.IpAddress;
+import org.onlab.packet.VlanId;
 import org.onosproject.net.DeviceId;
 import static com.google.common.base.MoreObjects.toStringHelper;
 import static com.google.common.base.Preconditions.checkArgument;
@@ -27,21 +28,52 @@
  * Key of multicast next objective store.
  */
 public class McastStoreKey {
+    // Identify a flow using group address, deviceId, and assigned vlan
     private final IpAddress mcastIp;
     private final DeviceId deviceId;
+    private final VlanId vlanId;
 
     /**
      * Constructs the key of multicast next objective store.
      *
      * @param mcastIp multicast group IP address
      * @param deviceId device ID
+     *
+     * @deprecated in 1.12 ("Magpie") release, use the constructor taking also the vlan id.
      */
+    @Deprecated
     public McastStoreKey(IpAddress mcastIp, DeviceId deviceId) {
         checkNotNull(mcastIp, "mcastIp cannot be null");
         checkNotNull(deviceId, "deviceId cannot be null");
         checkArgument(mcastIp.isMulticast(), "mcastIp must be a multicast address");
         this.mcastIp = mcastIp;
         this.deviceId = deviceId;
+        this.vlanId = null;
+    }
+
+    /**
+     * Constructs the key of multicast next objective store.
+     *
+     * @param mcastIp multicast group IP address
+     * @param deviceId device ID
+     * @param vlanId vlan id, must not be null
+     */
+    public McastStoreKey(IpAddress mcastIp, DeviceId deviceId, VlanId vlanId) {
+        checkNotNull(mcastIp, "mcastIp cannot be null");
+        checkNotNull(deviceId, "deviceId cannot be null");
+        checkNotNull(vlanId, "vlan id cannot be null");
+        checkArgument(mcastIp.isMulticast(), "mcastIp must be a multicast address");
+        this.mcastIp = mcastIp;
+        this.deviceId = deviceId;
+        // FIXME we should probably reject vlan ids that are not valid
+        this.vlanId = vlanId;
+    }
+
+    // Constructor for serialization; every field is left null
+    private McastStoreKey() {
+        this.mcastIp = null;
+        this.deviceId = null;
+        this.vlanId = null;
     }
 
     /**
@@ -62,6 +94,15 @@
         return deviceId;
     }
 
+    /**
+     * Returns the vlan ID of this key.
+     *
+     * @return vlan ID, null if the key was built with the deprecated constructor
+     */
+    public VlanId vlanId() {
+        return vlanId;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
@@ -73,12 +114,13 @@
         McastStoreKey that =
                 (McastStoreKey) o;
         return (Objects.equals(this.mcastIp, that.mcastIp) &&
-                Objects.equals(this.deviceId, that.deviceId));
+                Objects.equals(this.deviceId, that.deviceId) &&
+                Objects.equals(this.vlanId, that.vlanId));
     }
 
     @Override
     public int hashCode() {
-         return Objects.hash(mcastIp, deviceId);
+        return Objects.hash(mcastIp, deviceId, vlanId);
     }
 
     @Override
@@ -86,6 +128,7 @@
         return toStringHelper(getClass())
                 .add("mcastIp", mcastIp)
                 .add("deviceId", deviceId)
+                .add("vlanId", vlanId)
                 .toString();
     }
 }
diff --git a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKeySerializer.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKeySerializer.java
new file mode 100644
index 0000000..c32b0ba
--- /dev/null
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastStoreKeySerializer.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright 2018-present Open Networking Foundation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.segmentrouting.mcast;
+
+
+import com.esotericsoftware.kryo.Kryo;
+import com.esotericsoftware.kryo.Serializer;
+import com.esotericsoftware.kryo.io.Input;
+import com.esotericsoftware.kryo.io.Output;
+import org.onlab.packet.IpAddress;
+import org.onlab.packet.VlanId;
+import org.onosproject.net.DeviceId;
+
+/**
+ * Kryo serializer writing a {@link McastStoreKey} as group ip, device id and vlan id.
+ */
+class McastStoreKeySerializer extends Serializer<McastStoreKey> {
+
+    /**
+     * Creates {@link McastStoreKeySerializer} serializer instance.
+     */
+    McastStoreKeySerializer() {
+        // non-null, immutable
+        super(false, true);
+    }
+
+    @Override
+    public void write(Kryo kryo, Output output, McastStoreKey object) {
+        kryo.writeClassAndObject(output, object.mcastIp());
+        output.writeString(object.deviceId().toString());
+        kryo.writeClassAndObject(output, object.vlanId());
+    }
+
+    @Override
+    public McastStoreKey read(Kryo kryo, Input input, Class<McastStoreKey> type) {
+        IpAddress mcastIp = (IpAddress) kryo.readClassAndObject(input);
+        DeviceId deviceId = DeviceId.deviceId(input.readString());
+        VlanId vlanId = (VlanId) kryo.readClassAndObject(input);
+        // Keys built with the deprecated constructor are written with a null vlan id
+        return vlanId == null ? new McastStoreKey(mcastIp, deviceId) : new McastStoreKey(mcastIp, deviceId, vlanId);
+    }
+}
diff --git a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
index 2b13370..604f868 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastUtils.java
@@ -260,7 +260,7 @@
      * @return sources connect points or empty set if not found
      */
     Set<ConnectPoint> getSources(IpAddress mcastIp) {
-        // FIXME we should support different types of routes
+        // TODO we should support different types of routes
         McastRoute mcastRoute = srManager.multicastRouteService.getRoutes().stream()
                 .filter(mcastRouteInternal -> mcastRouteInternal.group().equals(mcastIp))
                 .findFirst().orElse(null);
@@ -275,7 +275,7 @@
      * @return map of sinks or empty map if not found
      */
     Map<HostId, Set<ConnectPoint>> getSinks(IpAddress mcastIp) {
-        // FIXME we should support different types of routes
+        // TODO we should support different types of routes
         McastRoute mcastRoute = srManager.multicastRouteService.getRoutes().stream()
                 .filter(mcastRouteInternal -> mcastRouteInternal.group().equals(mcastIp))
                 .findFirst().orElse(null);
@@ -292,7 +292,7 @@
      * @return the map of the sinks affected
      */
     Map<HostId, Set<ConnectPoint>> getAffectedSinks(DeviceId egressDevice,
-                                                            IpAddress mcastIp) {
+                                                    IpAddress mcastIp) {
         return getSinks(mcastIp).entrySet()
                 .stream()
                 .filter(hostIdSetEntry -> hostIdSetEntry.getValue().stream()
diff --git a/app/src/main/resources/OSGI-INF/blueprint/shell-config.xml b/app/src/main/resources/OSGI-INF/blueprint/shell-config.xml
index 1c8610a..d575413 100644
--- a/app/src/main/resources/OSGI-INF/blueprint/shell-config.xml
+++ b/app/src/main/resources/OSGI-INF/blueprint/shell-config.xml
@@ -79,6 +79,13 @@
             </optional-completers>
         </command>
         <command>
+            <action class="org.onosproject.segmentrouting.cli.McastRoleListCommand"/>
+            <optional-completers>
+                <entry key="-gAddr" value-ref="mcastGroupCompleter"/>
+                <entry key="-src" value-ref="connectpointCompleter"/>
+            </optional-completers>
+        </command>
+        <command>
             <action class="org.onosproject.segmentrouting.cli.McastTreeListCommand"/>
             <optional-completers>
                 <entry key="-gAddr" value-ref="mcastGroupCompleter"/>