Support list of nodes and list of links to avoid for AVOID policy
diff --git a/src/main/java/net/onrc/onos/apps/segmentrouting/ECMPShortestPathGraph.java b/src/main/java/net/onrc/onos/apps/segmentrouting/ECMPShortestPathGraph.java
index f8ba3eb..cb3ec9b 100644
--- a/src/main/java/net/onrc/onos/apps/segmentrouting/ECMPShortestPathGraph.java
+++ b/src/main/java/net/onrc/onos/apps/segmentrouting/ECMPShortestPathGraph.java
@@ -3,6 +3,7 @@
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.LinkedList;
+import java.util.List;
 
 import net.onrc.onos.core.intent.Path;
 import net.onrc.onos.core.topology.Link;
@@ -35,17 +36,24 @@
      * Constructor
      *
      * @param rootSwitch root of the BFS tree
+     * @param linkListToAvoid
+     * @param dpidListToAvoid
+     */
+    public ECMPShortestPathGraph(Switch rootSwitch, List<String> dpidListToAvoid, List<Link> linkListToAvoid) {
+        this.rootSwitch = rootSwitch;
+        calcECMPShortestPathGraph(dpidListToAvoid, linkListToAvoid);
+    }
+
+    /**
+     * Constructor
+     *
+     * @param rootSwitch root of the BFS tree
      */
     public ECMPShortestPathGraph(Switch rootSwitch) {
         this.rootSwitch = rootSwitch;
         calcECMPShortestPathGraph();
     }
 
-    public ECMPShortestPathGraph(Switch rootSw, Switch swToAvoid) {
-        this.rootSwitch = rootSw;
-        calcECMPShortestPathGraph(swToAvoid.getDpid());
-    }
-
     /**
      * Calculates the BFS tree using any provided constraints and Intents.
      */
@@ -130,18 +138,31 @@
     /**
      * Calculates the BFS tree using any provided constraints and Intents.
      */
-    private void calcECMPShortestPathGraph(Dpid avoid) {
+    private void calcECMPShortestPathGraph(List<String> dpidListToAvoid, List<Link> linksToAvoid) {
         switchQueue.add(rootSwitch);
         int currDistance = 0;
         distanceQueue.add(currDistance);
         switchSearched.put(rootSwitch.getDpid(), currDistance);
+        boolean foundLinkToAvoid = false;
         while (!switchQueue.isEmpty()) {
             Switch sw = switchQueue.poll();
             Switch prevSw = null;
             currDistance = distanceQueue.poll();
             for (Link link : sw.getOutgoingLinks()) {
+                for (Link linkToAvoid: linksToAvoid) {
+                    // TODO: equls should work
+                    //if (link.equals(linkToAvoid)) {
+                    if (linkContains(link, linksToAvoid)) {
+                        foundLinkToAvoid = true;
+                        break;
+                    }
+                }
+                if (foundLinkToAvoid) {
+                    foundLinkToAvoid = false;
+                    continue;
+                }
                 Switch reachedSwitch = link.getDstPort().getSwitch();
-                if (reachedSwitch.getDpid().equals(avoid))
+                if (dpidListToAvoid.contains(reachedSwitch.getDpid().toString()))
                     continue;
                 if ((prevSw != null)
                         && (prevSw.getDpid().equals(reachedSwitch.getDpid())))
@@ -211,6 +232,28 @@
     }
 
 
+    private boolean linkContains(Link link, List<Link> links) {
+
+        Switch srcSwitch1 = link.getSrcSwitch();
+        Switch dstSwitch1 = link.getDstSwitch();
+        long srcPort1 = link.getSrcPort().getPortNumber().value();
+        long dstPort1 = link.getDstPort().getPortNumber().value();
+
+        for (Link link2: links) {
+            Switch srcSwitch2 = link2.getSrcSwitch();
+            Switch dstSwitch2 = link2.getDstSwitch();
+            long srcPort2 = link2.getSrcPort().getPortNumber().value();
+            long dstPort2 = link2.getDstPort().getPortNumber().value();
+
+            if (srcSwitch1.getDpid().toString().equals(srcSwitch2.getDpid().toString())
+             && dstSwitch1.getDpid().toString().equals(dstSwitch2.getDpid().toString())
+             && srcPort1 == srcPort2 && dstPort1 == dstPort2)
+                return true;
+        }
+
+        return false;
+    }
+
     private void getDFSPaths(Dpid dstSwitchDpid, Path path, ArrayList<Path> paths) {
         Dpid rootSwitchDpid = rootSwitch.getDpid();
         for (LinkData upstreamLink : upstreamLinks.get(dstSwitchDpid)) {
diff --git a/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingManager.java b/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingManager.java
index c11d604..92aebb3 100644
--- a/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingManager.java
+++ b/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingManager.java
@@ -274,8 +274,8 @@
             }
         });
 
-        testMode = TEST_MODE.POLICY_AVOID;
-        testTask.reschedule(20, TimeUnit.SECONDS);
+        //testMode = TEST_MODE.POLICY_AVOID;
+        //testTask.reschedule(20, TimeUnit.SECONDS);
     }
 
     @Override
@@ -1393,15 +1393,20 @@
 
         Switch srcSwitch = mutableTopology.getSwitch(new Dpid(srcNode));
         Switch dstSwitch = mutableTopology.getSwitch(new Dpid(dstNode));
-        Switch swToAvoid =
-                mutableTopology.getSwitch(new Dpid(nodesToAvoid.get(0)));
-        if (srcSwitch == null || dstSwitch == null || swToAvoid == null) {
+        if (srcSwitch == null || dstSwitch == null) {
             log.warn("Switches are not found!");
             return false;
         }
+        List<String> dpidListToAvoid = new ArrayList<String>();
+        for (int nodeId: nodesToAvoid) {
+            Switch swToAvoid =
+                    mutableTopology.getSwitch(new Dpid(nodeId));
+            dpidListToAvoid.add(swToAvoid.getDpid().toString());
+        }
 
         SegmentRoutingPolicy avoidPolicy = new SegmentRoutingPolicyAvoid(this,
-                pid, policyMatch, priority, srcSwitch, dstSwitch, swToAvoid);
+                pid, policyMatch, priority, srcSwitch, dstSwitch,
+                dpidListToAvoid, linksToAvoid);
         if (avoidPolicy.createPolicy()) {
             policyTable.put(pid, avoidPolicy);
             // TODO: handle multi-instance
@@ -2174,6 +2179,7 @@
             printEcmpPaths(ecmpPaths2);
             */
 
+            /*
             String pid = "p1";
             MACAddress srcMac = null;
             MACAddress dstMac = null;
@@ -2191,6 +2197,37 @@
             List<Link> linksToAvoid = null;
             createPolicy(pid, srcMac, dstMac, etherType, srcIp, dstIp, ipProto,
                     srcPort, dstPort, priority, srcNode, dstNode, nodesToAvoid, linksToAvoid);
+            */
+
+            String pid = "p1";
+            MACAddress srcMac = null;
+            MACAddress dstMac = null;
+            Short etherType = Ethernet.TYPE_IPV4;
+            IPv4Net srcIp = new IPv4Net("10.0.1.1/32");
+            IPv4Net dstIp = new IPv4Net("7.7.7.7/32");
+            Byte ipProto = IPv4.PROTOCOL_ICMP;
+            Short srcPort = 0;
+            Short dstPort = 0;
+            int priority = 10000;
+            int srcNode = 1;
+            int dstNode = 6;
+            List<Integer> nodesToAvoid = new ArrayList<Integer>();
+            //nodesToAvoid.add(5);
+            List<Link> linksToAvoid = new ArrayList<Link>();
+
+            Switch sw = mutableTopology.getSwitch(new Dpid(2));
+            Link link = sw.getLinkToNeighbor(new Dpid(5));
+            Switch sw2 = mutableTopology.getSwitch(new Dpid(4));
+            Link link2 = sw2.getLinkToNeighbor(new Dpid(6));
+            Switch sw3 = mutableTopology.getSwitch(new Dpid(1));
+            Link link3 = sw3.getLinkToNeighbor(new Dpid(3));
+            linksToAvoid.add(link);
+            linksToAvoid.add(link2);
+            linksToAvoid.add(link3);
+
+            createPolicy(pid, srcMac, dstMac, etherType, srcIp, dstIp, ipProto,
+                    srcPort, dstPort, priority, srcNode, dstNode, nodesToAvoid, linksToAvoid);
+
 
         }
     }
diff --git a/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingPolicyAvoid.java b/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingPolicyAvoid.java
index 478eeb6..608b28f 100644
--- a/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingPolicyAvoid.java
+++ b/src/main/java/net/onrc/onos/apps/segmentrouting/SegmentRoutingPolicyAvoid.java
@@ -5,8 +5,10 @@
 
 import net.onrc.onos.core.intent.Path;
 import net.onrc.onos.core.matchaction.match.PacketMatch;
+import net.onrc.onos.core.topology.Link;
 import net.onrc.onos.core.topology.LinkData;
 import net.onrc.onos.core.topology.Switch;
+import net.onrc.onos.core.util.Dpid;
 
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
@@ -18,7 +20,8 @@
 
     private Switch srcSwitch;
     private Switch dstSwitch;
-    private Switch switchToAvoid;
+    private List<String> dpidListToAvoid;
+    private List<Link> linkListToAvoid;
 
     public SegmentRoutingPolicyAvoid(PolicyNotification policyNotication) {
         super(policyNotication);
@@ -26,18 +29,21 @@
     }
 
     public SegmentRoutingPolicyAvoid(SegmentRoutingManager srm, String pid,
-            PacketMatch match, int priority, Switch from, Switch to, Switch swToAvoid) {
+            PacketMatch match, int priority, Switch from, Switch to,
+            List<String> dpidList, List<net.onrc.onos.core.topology.Link> linksToAvoid) {
         super(srm, pid, PolicyType.AVOID, match, priority);
         this.srcSwitch = from;
         this.dstSwitch = to;
-        this.switchToAvoid = swToAvoid;
+        this.dpidListToAvoid = dpidList;
+        this.linkListToAvoid = linksToAvoid;
     }
 
     @Override
     public boolean createPolicy() {
 
         //Create a tunnel from srcSwitch to dstSwitch avoiding swToAvoid;
-        ECMPShortestPathGraph graph = new ECMPShortestPathGraph(srcSwitch, switchToAvoid);
+        ECMPShortestPathGraph graph = new ECMPShortestPathGraph(srcSwitch,
+                dpidListToAvoid, linkListToAvoid);
         List<Path> ecmpPaths = graph.getECMPPaths(dstSwitch);
 
         for (Path path: ecmpPaths) {
@@ -49,8 +55,8 @@
             }
             String dstDpid = path.get(0).getDst().getDpid().toString();
             labelStack.add(Integer.valueOf(srManager.getMplsLabel(dstDpid)));
-            String nodeToAvoid = srManager.getMplsLabel(switchToAvoid.getDpid().toString());
-            OptimizeLabelStack(labelStack, switchToAvoid);
+            //String nodeToAvoid = srManager.getMplsLabel(switchToAvoid.getDpid().toString());
+            OptimizeLabelStack(labelStack);
             SegmentRoutingTunnel avoidTunnel = new SegmentRoutingTunnel(
                     srManager, "avoid-0", labelStack);
             if (avoidTunnel.createTunnel()) {
@@ -75,7 +81,7 @@
      *
      * @param labelStack List of label IDs
      */
-    private void OptimizeLabelStack(List<Integer> labelStack, Switch nodeToAvoid) {
+    private void OptimizeLabelStack(List<Integer> labelStack) {
 
         // {101, 103, 104, 106}
         // source = 101
@@ -95,11 +101,13 @@
         Switch nodeToCheck = srManager.getSwitchFromNodeId(
                 labelStack.get(labelStack.size()-i).toString());
         ECMPShortestPathGraph ecmpGraph = new ECMPShortestPathGraph(srcNode);
-        while (!nodeToCheck.getDpid().equals(srcNode)) {
+        while (!nodeToCheck.getDpid().toString().equals(srcNode.getDpid().toString())) {
             List<Path> paths = ecmpGraph.getECMPPaths(nodeToCheck);
             for (Path path: paths) {
                 for (LinkData link: path) {
-                    if (link.getSrc().getDpid().equals(switchToAvoid.getDpid())) {
+                    if (dpidListToAvoid.contains(
+                            link.getSrc().getDpid().toString())
+                            || linkContains(link, linkListToAvoid)) {
                         violated = true;
                         break;
                     }
@@ -121,4 +129,26 @@
             }
         }
     }
+
+    private boolean linkContains(LinkData link, List<Link> links) {
+
+        Dpid srcSwitch1 = link.getSrc().getDpid();
+        Dpid dstSwitch1 = link.getDst().getDpid();
+        long srcPort1 = link.getSrc().getPortNumber().value();
+        long dstPort1 = link.getDst().getPortNumber().value();
+
+        for (Link link2: links) {
+            Switch srcSwitch2 = link2.getSrcSwitch();
+            Switch dstSwitch2 = link2.getDstSwitch();
+            long srcPort2 = link2.getSrcPort().getPortNumber().value();
+            long dstPort2 = link2.getDstPort().getPortNumber().value();
+
+            if (srcSwitch1.toString().equals(srcSwitch2.getDpid().toString())
+             && dstSwitch1.toString().equals(dstSwitch2.getDpid().toString())
+             && srcPort1 == srcPort2 && dstPort1 == dstPort2)
+                return true;
+        }
+
+        return false;
+    }
 }