ONOS-2379 Reactive Fwd improvements prune bad flows from switches when a link goes down

Change-Id: I27de61abc6225d12fd4ba5fa36d58ec4061f9db5
diff --git a/apps/fwd/src/main/java/org/onosproject/fwd/ReactiveForwarding.java b/apps/fwd/src/main/java/org/onosproject/fwd/ReactiveForwarding.java
index c557dbb..3225f81 100644
--- a/apps/fwd/src/main/java/org/onosproject/fwd/ReactiveForwarding.java
+++ b/apps/fwd/src/main/java/org/onosproject/fwd/ReactiveForwarding.java
@@ -15,6 +15,7 @@
  */
 package org.onosproject.fwd;
 
+import com.google.common.collect.ImmutableSet;
 import org.apache.felix.scr.annotations.Activate;
 import org.apache.felix.scr.annotations.Component;
 import org.apache.felix.scr.annotations.Deactivate;
@@ -29,35 +30,51 @@
 import org.onlab.packet.IPv6;
 import org.onlab.packet.Ip4Prefix;
 import org.onlab.packet.Ip6Prefix;
+import org.onlab.packet.MacAddress;
 import org.onlab.packet.TCP;
 import org.onlab.packet.UDP;
 import org.onlab.packet.VlanId;
 import org.onosproject.cfg.ComponentConfigService;
 import org.onosproject.core.ApplicationId;
 import org.onosproject.core.CoreService;
+import org.onosproject.event.Event;
+import org.onosproject.net.ConnectPoint;
+import org.onosproject.net.DeviceId;
 import org.onosproject.net.Host;
 import org.onosproject.net.HostId;
+import org.onosproject.net.Link;
 import org.onosproject.net.Path;
 import org.onosproject.net.PortNumber;
 import org.onosproject.net.flow.DefaultTrafficSelector;
 import org.onosproject.net.flow.DefaultTrafficTreatment;
+import org.onosproject.net.flow.FlowEntry;
+import org.onosproject.net.flow.FlowRule;
 import org.onosproject.net.flow.FlowRuleService;
 import org.onosproject.net.flow.TrafficSelector;
 import org.onosproject.net.flow.TrafficTreatment;
+import org.onosproject.net.flow.criteria.Criterion;
+import org.onosproject.net.flow.criteria.EthCriterion;
+import org.onosproject.net.flow.instructions.Instruction;
+import org.onosproject.net.flow.instructions.Instructions;
 import org.onosproject.net.flowobjective.DefaultForwardingObjective;
 import org.onosproject.net.flowobjective.FlowObjectiveService;
 import org.onosproject.net.flowobjective.ForwardingObjective;
 import org.onosproject.net.host.HostService;
+import org.onosproject.net.link.LinkEvent;
 import org.onosproject.net.packet.InboundPacket;
 import org.onosproject.net.packet.PacketContext;
 import org.onosproject.net.packet.PacketPriority;
 import org.onosproject.net.packet.PacketProcessor;
 import org.onosproject.net.packet.PacketService;
+import org.onosproject.net.topology.TopologyEvent;
+import org.onosproject.net.topology.TopologyListener;
 import org.onosproject.net.topology.TopologyService;
 import org.osgi.service.component.ComponentContext;
 import org.slf4j.Logger;
 
 import java.util.Dictionary;
+import java.util.List;
+import java.util.Objects;
 import java.util.Set;
 
 import static com.google.common.base.Strings.isNullOrEmpty;
@@ -155,16 +172,21 @@
                     "default is false")
     private boolean matchIcmpFields = false;
 
+
     @Property(name = "ignoreIPv4Multicast", boolValue = false,
             label = "Ignore (do not forward) IPv4 multicast packets; default is false")
     private boolean ignoreIpv4McastPackets = false;
 
+    private final TopologyListener topologyListener = new InternalTopologyListener();
+
+
     @Activate
     public void activate(ComponentContext context) {
         cfgService.registerProperties(getClass());
         appId = coreService.registerApplication("org.onosproject.fwd");
 
         packetService.addProcessor(processor, PacketProcessor.ADVISOR_MAX + 2);
+        topologyService.addListener(topologyListener);
         readComponentConfiguration(context);
         requestIntercepts();
 
@@ -177,6 +199,7 @@
         withdrawIntercepts();
         flowRuleService.removeFlowRulesById(appId);
         packetService.removeProcessor(processor);
+        topologyService.removeListener(topologyListener);
         processor = null;
         log.info("Stopped");
     }
@@ -383,6 +406,7 @@
         public void process(PacketContext context) {
             // Stop processing if the packet has been handled, since we
             // can't do any more to it.
+
             if (context.isHandled()) {
                 return;
             }
@@ -646,4 +670,161 @@
             packetOut(context, portNumber);
         }
     }
+
+    private class InternalTopologyListener implements TopologyListener {
+        @Override
+        public void event(TopologyEvent event) {
+            List<Event> reasons = event.reasons();
+            if (reasons != null) {
+                reasons.forEach(re -> {
+                    if (re instanceof LinkEvent) {
+                        LinkEvent le = (LinkEvent) re;
+                        if (le.type() == LinkEvent.Type.LINK_REMOVED) {
+                            fixBlackhole(le.subject().src());
+                        }
+                    }
+                });
+            }
+        }
+    }
+
+    private void fixBlackhole(ConnectPoint egress) {
+        Set<FlowEntry> rules =  getFlowRulesFrom(egress);
+        Set<SrcDstPair> pairs = findSrcDstPairs(rules);
+
+        for (SrcDstPair sd: pairs) {
+            // get the edge deviceID for the src host
+            DeviceId srcId = hostService.getHost(HostId.hostId(sd.src)).location().deviceId();
+            DeviceId dstId = hostService.getHost(HostId.hostId(sd.dst)).location().deviceId();
+            log.trace("SRC ID is " + srcId + ", DST ID is " + dstId);
+
+            cleanFlowRules(sd, egress.deviceId());
+
+            Set<Path> shortestPaths =
+                    topologyService.getPaths(topologyService.currentTopology(), egress.deviceId(), srcId);
+            backTrackBadNodes(shortestPaths, dstId, sd);
+        }
+    }
+
+    // Backtracks from link down event to remove flows that lead to blackhole
+    private void backTrackBadNodes(Set<Path> shortestPaths, DeviceId dstId, SrcDstPair sd) {
+        for (Path p: shortestPaths) {
+            List<Link> pathLinks = p.links();
+            for (int i = 0; i < pathLinks.size(); i = i + 1) {
+                Link curLink = pathLinks.get(i);
+                DeviceId curDevice = curLink.src().deviceId();
+                    log.trace("Currently inspecting device: " + curDevice);
+
+                // skipping the first link because this link's src has already been pruned beforehand
+                if (i != 0) {
+                    cleanFlowRules(sd, curDevice);
+                }
+
+                Set<Path> pathsFromCurDevice = topologyService.getPaths(topologyService.currentTopology(),
+                        curDevice, dstId);
+                if (pickForwardPath(pathsFromCurDevice, curLink.src().port()) != null) {
+                    break;
+                } else {
+                    if (i + 1 == pathLinks.size()) {
+                        cleanFlowRules(sd, curLink.dst().deviceId());
+                    }
+                }
+            }
+        }
+    }
+
+    // Removes flow rules off specified device with specific SrcDstPair
+    private void cleanFlowRules(SrcDstPair pair, DeviceId id) {
+        log.trace("Searching for flow rules to remove from: " + id);
+        log.trace("Removing flows w/ SRC=" + pair.src + ", DST=" + pair.dst);
+        for (FlowEntry r : flowRuleService.getFlowEntries(id)) {
+            boolean matchesSrc = false, matchesDst = false;
+            for (Instruction i : r.treatment().allInstructions()) {
+                if (i.type() == Instruction.Type.OUTPUT) {
+                    //if the flow has matching src and dst
+                    for (Criterion cr : r.selector().criteria()) {
+                        if (cr.type() == Criterion.Type.ETH_DST) {
+                            if (((EthCriterion) cr).mac().equals(pair.dst)) {
+                                matchesDst = true;
+                            }
+                        } else if (cr.type() == Criterion.Type.ETH_SRC) {
+                            if (((EthCriterion) cr).mac().equals(pair.src)) {
+                                matchesSrc = true;
+                            }
+                        }
+                    }
+                }
+            }
+            if (matchesDst && matchesSrc) {
+                log.trace("Removed flow rule from device: " + id);
+                flowRuleService.removeFlowRules((FlowRule) r);
+            }
+        }
+
+    }
+
+    // Returns a set of src/dst MAC pairs extracted from the specified set of flow entries
+    private Set<SrcDstPair> findSrcDstPairs(Set<FlowEntry> rules) {
+        ImmutableSet.Builder<SrcDstPair> builder = ImmutableSet.builder();
+        for (FlowEntry r: rules) {
+            MacAddress src = null, dst = null;
+            for (Criterion cr: r.selector().criteria()) {
+                if (cr.type() == Criterion.Type.ETH_DST) {
+                    dst = ((EthCriterion) cr).mac();
+                } else if (cr.type() == Criterion.Type.ETH_SRC) {
+                    src = ((EthCriterion) cr).mac();
+                }
+            }
+            builder.add(new SrcDstPair(src, dst));
+        }
+        return builder.build();
+    }
+
+    // Returns set of flowEntries which were created by this application and which egress from the
+    // specified connection port
+    private Set<FlowEntry> getFlowRulesFrom(ConnectPoint egress) {
+        ImmutableSet.Builder<FlowEntry> builder = ImmutableSet.builder();
+        flowRuleService.getFlowEntries(egress.deviceId()).forEach(r -> {
+            if (r.appId() == appId.id()) {
+                r.treatment().allInstructions().forEach(i -> {
+                    if (i.type() == Instruction.Type.OUTPUT) {
+                        if (((Instructions.OutputInstruction) i).port().equals(egress.port())) {
+                            builder.add(r);
+                        }
+                    }
+                });
+            }
+        });
+
+        return builder.build();
+    }
+
+    // Wrapper class for a source and destination pair of MAC addresses
+    private final class SrcDstPair {
+        final MacAddress src;
+        final MacAddress dst;
+
+        private SrcDstPair(MacAddress src, MacAddress dst) {
+            this.src = src;
+            this.dst = dst;
+        }
+
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (o == null || getClass() != o.getClass()) {
+                return false;
+            }
+            SrcDstPair that = (SrcDstPair) o;
+            return Objects.equals(src, that.src) &&
+                    Objects.equals(dst, that.dst);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(src, dst);
+        }
+    }
 }