[CORD-2828] Multicast support for H-AGG

Change-Id: I637465bcff6454f414c349a0ab054d66e3d17a05
(cherry picked from commit 4e9a7ce2f70a21a71f828335a6041101f2bf833e)
diff --git a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
index d35d253..a791ecb 100644
--- a/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
+++ b/app/src/main/java/org/onosproject/segmentrouting/mcast/McastHandler.java
@@ -86,7 +86,6 @@
 import java.util.concurrent.locks.ReentrantLock;
 import java.util.stream.Collectors;
 
-import static com.google.common.base.Preconditions.checkState;
 import static java.util.concurrent.Executors.newScheduledThreadPool;
 import static org.onlab.util.Tools.groupedThreads;
 import static org.onosproject.net.flow.criteria.Criterion.Type.VLAN_VID;
@@ -381,8 +380,7 @@
             // Find out the ingress, transit and egress device of the affected group
             DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
                     .stream().findAny().orElse(null);
-            DeviceId transitDevice = getDevice(mcastIp, TRANSIT)
-                    .stream().findAny().orElse(null);
+            Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
             Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
 
             // Verify leadership on the operation
@@ -391,15 +389,17 @@
                 return;
             }
 
-            // If there are egress devices, sinks could be only on the ingress
+            // If there are no egress devices, sinks could be only on the ingress
             if (!egressDevices.isEmpty()) {
                 egressDevices.forEach(
                         deviceId -> removeGroupFromDevice(deviceId, mcastIp, assignedVlan(null))
                 );
             }
-            // Transit could be null
-            if (transitDevice != null) {
-                removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null));
+            // Transit could be empty if sinks are on the ingress
+            if (!transitDevices.isEmpty()) {
+                transitDevices.forEach(
+                        deviceId -> removeGroupFromDevice(deviceId, mcastIp, assignedVlan(null))
+                );
             }
             // Ingress device should be not null
             if (ingressDevice != null) {
@@ -514,9 +514,12 @@
             Optional<Path> mcastPath = getPath(source.deviceId(), sink.deviceId(), mcastIp);
             if (mcastPath.isPresent()) {
                 List<Link> links = mcastPath.get().links();
-                checkState(links.size() == 2,
-                           "Path in leaf-spine topology should always be two hops: ", links);
 
+                // Setup mcast role for ingress
+                mcastRoleStore.put(new McastStoreKey(mcastIp, source.deviceId()),
+                                   INGRESS);
+
+                // Setup properly the transit
                 links.forEach(link -> {
                     addPortToDevice(link.src().deviceId(), link.src().port(), mcastIp,
                                     assignedVlan(link.src().deviceId().equals(source.deviceId()) ? source : null));
@@ -524,14 +527,15 @@
                                       assignedVlan(null), mcastIp, null);
                 });
 
+                // Setup mcast role for the transit
+                links.stream()
+                        .filter(link -> !link.dst().deviceId().equals(sink.deviceId()))
+                        .forEach(link -> mcastRoleStore.put(new McastStoreKey(mcastIp, link.dst().deviceId()),
+                                                            TRANSIT));
+
                 // Process the egress device
                 addPortToDevice(sink.deviceId(), sink.port(), mcastIp, assignedVlan(null));
-
-                // Setup mcast roles
-                mcastRoleStore.put(new McastStoreKey(mcastIp, source.deviceId()),
-                                   INGRESS);
-                mcastRoleStore.put(new McastStoreKey(mcastIp, links.get(0).dst().deviceId()),
-                                   TRANSIT);
+                // Setup mcast role for egress
                 mcastRoleStore.put(new McastStoreKey(mcastIp, sink.deviceId()),
                                    EGRESS);
             } else {
@@ -561,16 +565,16 @@
                 // Find out the ingress, transit and egress device of affected group
                 DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
                         .stream().findAny().orElse(null);
-                DeviceId transitDevice = getDevice(mcastIp, TRANSIT)
-                        .stream().findAny().orElse(null);
+                Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
                 Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
                 ConnectPoint source = getSource(mcastIp);
 
-                // Do not proceed if any of these info is missing
-                if (ingressDevice == null || transitDevice == null
-                        || egressDevices == null || source == null) {
-                    log.warn("Missing ingress {}, transit {}, egress {} devices or source {}",
-                             ingressDevice, transitDevice, egressDevices, source);
+                // Do not proceed if ingress device or source of this group are missing
+                // If sinks are in other leafs, we have ingress, transit, egress, and source
+                // If sinks are in the same leaf, we have just ingress and source
+                if (ingressDevice == null || source == null) {
+                    log.warn("Missing ingress {} or source {} for group {}",
+                             ingressDevice, source, mcastIp);
                     return;
                 }
 
@@ -582,17 +586,11 @@
                 }
 
                 // Remove entire transit
-                removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null));
+                transitDevices.forEach(transitDevice ->
+                                removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null)));
 
-                // Remove transit-facing port on ingress device
-                PortNumber ingressTransitPort = ingressTransitPort(mcastIp);
-                if (ingressTransitPort != null) {
-                    boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
-                                                          mcastIp, assignedVlan(source));
-                    if (isLast) {
-                        mcastRoleStore.remove(new McastStoreKey(mcastIp, ingressDevice));
-                    }
-                }
+                // Remove transit-facing ports on the ingress device
+                removeIngressTransitPorts(mcastIp, ingressDevice, source);
 
                 // Construct a new path for each egress device
                 egressDevices.forEach(egressDevice -> {
@@ -629,8 +627,7 @@
                 // Find out the ingress, transit and egress device of affected group
                 DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
                         .stream().findAny().orElse(null);
-                DeviceId transitDevice = getDevice(mcastIp, TRANSIT)
-                        .stream().findAny().orElse(null);
+                Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
                 Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
                 ConnectPoint source = getSource(mcastIp);
 
@@ -650,9 +647,10 @@
                 }
 
                 // If it exists, we have to remove it in any case
-                if (transitDevice != null) {
+                if (!transitDevices.isEmpty()) {
                     // Remove entire transit
-                    removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null));
+                    transitDevices.forEach(transitDevice ->
+                                    removeGroupFromDevice(transitDevice, mcastIp, assignedVlan(null)));
                 }
                 // If the ingress is down
                 if (ingressDevice.equals(deviceDown)) {
@@ -667,18 +665,9 @@
                     }
                 } else {
                     // Egress or transit could be down at this point
-                    // Get the ingress-transit port if it exists
-                    PortNumber ingressTransitPort = ingressTransitPort(mcastIp);
-                    if (ingressTransitPort != null) {
-                        // Remove transit-facing port on ingress device
-                        boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
-                                                              mcastIp, assignedVlan(source));
-                        // There are no further ports
-                        if (isLast) {
-                            // Remove entire ingress
-                            mcastRoleStore.remove(new McastStoreKey(mcastIp, ingressDevice));
-                        }
-                    }
+                    // Get the ingress-transit ports if they exist
+                    removeIngressTransitPorts(mcastIp, ingressDevice, source);
+
                     // One of the egress device is down
                     if (egressDevices.contains(deviceDown)) {
                         // Remove entire device down
@@ -713,6 +702,27 @@
     }
 
     /**
+     * Utility method to remove all the ingress transit ports.
+     *
+     * @param mcastIp the group ip
+     * @param ingressDevice the ingress device for this group
+     * @param source the source connect point
+     */
+    private void removeIngressTransitPorts(IpAddress mcastIp, DeviceId ingressDevice,
+                                           ConnectPoint source) {
+        Set<PortNumber> ingressTransitPorts = ingressTransitPort(mcastIp);
+        ingressTransitPorts.forEach(ingressTransitPort -> {
+            if (ingressTransitPort != null) {
+                boolean isLast = removePortFromDevice(ingressDevice, ingressTransitPort,
+                                                      mcastIp, assignedVlan(source));
+                if (isLast) {
+                    mcastRoleStore.remove(new McastStoreKey(mcastIp, ingressDevice));
+                }
+            }
+        });
+    }
+
+    /**
      * Adds filtering objective for given device and port.
      *
      * @param deviceId device ID
@@ -910,6 +920,11 @@
     private void installPath(IpAddress mcastIp, ConnectPoint source, Path mcastPath) {
         // Get Links
         List<Link> links = mcastPath.links();
+
+        // Setup new ingress mcast role
+        mcastRoleStore.put(new McastStoreKey(mcastIp, links.get(0).src().deviceId()),
+                           INGRESS);
+
         // For each link, modify the next on the source device adding the src port
         // and a new filter objective on the destination port
         links.forEach(link -> {
@@ -918,12 +933,12 @@
             addFilterToDevice(link.dst().deviceId(), link.dst().port(),
                               assignedVlan(null), mcastIp, null);
         });
-        // Setup new transit mcast role
-        mcastRoleStore.put(new McastStoreKey(mcastIp, links.get(0).dst().deviceId()),
-                           TRANSIT);
-        // Setup new ingress mcast role
-        mcastRoleStore.put(new McastStoreKey(mcastIp, links.get(0).src().deviceId()),
-                           INGRESS);
+
+        // Setup mcast role for the transit
+        links.stream()
+                .filter(link -> !link.src().deviceId().equals(source.deviceId()))
+                .forEach(link -> mcastRoleStore.put(new McastStoreKey(mcastIp, link.src().deviceId()),
+                                                    TRANSIT));
     }
 
     /**
@@ -1307,24 +1322,25 @@
      * @param mcastIp multicast IP
      * @return spine-facing port on ingress device
      */
-    private PortNumber ingressTransitPort(IpAddress mcastIp) {
+    private Set<PortNumber> ingressTransitPort(IpAddress mcastIp) {
         DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
                 .stream().findAny().orElse(null);
+        ImmutableSet.Builder<PortNumber> portBuilder = ImmutableSet.builder();
         if (ingressDevice != null) {
             NextObjective nextObj = mcastNextObjStore
                     .get(new McastStoreKey(mcastIp, ingressDevice)).value();
             Set<PortNumber> ports = getPorts(nextObj.next());
-
+            // Let's find out all the ingress-transit ports
             for (PortNumber port : ports) {
                 // Spine-facing port should have no subnet and no xconnect
                 if (srManager.deviceConfiguration() != null &&
                         srManager.deviceConfiguration().getPortSubnets(ingressDevice, port).isEmpty() &&
                         !srManager.xConnectHandler.hasXConnect(new ConnectPoint(ingressDevice, port))) {
-                    return port;
+                    portBuilder.add(port);
                 }
             }
         }
-        return null;
+        return portBuilder.build();
     }
 
     /**
@@ -1480,8 +1496,7 @@
                         // and issue a check of the next objectives in place
                         DeviceId ingressDevice = getDevice(mcastIp, INGRESS)
                                 .stream().findAny().orElse(null);
-                        DeviceId transitDevice = getDevice(mcastIp, TRANSIT)
-                                .stream().findAny().orElse(null);
+                        Set<DeviceId> transitDevices = getDevice(mcastIp, TRANSIT);
                         Set<DeviceId> egressDevices = getDevice(mcastIp, EGRESS);
                         // Get source and sinks from Mcast Route Service and warn about errors
                         ConnectPoint source = getSource(mcastIp);
@@ -1509,8 +1524,8 @@
                         // Create the set of the devices to be processed
                         ImmutableSet.Builder<DeviceId> devicesBuilder = ImmutableSet.builder();
                         devicesBuilder.add(ingressDevice);
-                        if (transitDevice != null) {
-                            devicesBuilder.add(transitDevice);
+                        if (!transitDevices.isEmpty()) {
+                            devicesBuilder.addAll(transitDevices);
                         }
                         if (!egressDevices.isEmpty()) {
                             devicesBuilder.addAll(egressDevices);