Complete complex scenario 1

Change-Id: I1c14fd5ffe017523bf32f12efc150711880ed71d
diff --git a/IGMP.py b/IGMP.py
index 1b24de3..db45407 100644
--- a/IGMP.py
+++ b/IGMP.py
@@ -145,7 +145,7 @@
             ip.ttl = 1
             ip.proto = 2
             ip.tos = 0xc0
-            ip.options = [IPOption_Router_Alert()]
+            #ip.options = [IPOption_Router_Alert()]
 
             if igmp.type == IGMP_TYPE_MEMBERSHIP_QUERY:
                 if igmp.gaddr == "0.0.0.0":
diff --git a/olt-complex.py b/olt-complex.py
index bf4486d..4a3e579 100644
--- a/olt-complex.py
+++ b/olt-complex.py
@@ -5,6 +5,8 @@
 import logging
 from oftest.testutils import *
 from oltbase import OltBaseTest
+from oltconstants import *
+
 
 class Thing(object):
     """An object we can stash arbitrary attributes for easy reach"""
@@ -61,13 +63,16 @@
 
         # Some constants
         c_vlan_id = 111
+        mcast_vlan_id = 140
         mcast_groups = Thing(
             ch1="230.10.10.10",
             ch2="231.11.11.11",
             ch3="232.12.12.12",
-            ch4="233.13.13.13"
+            ch4="233.13.13.13",
+            ch5="234.14.14.14"
         )
         onu1 = Thing(
+            port=onu_port,
             ip="13.14.14.13",
             mac="b6:b8:3e:fb:1a:3f",
             s_vlan_id=13
@@ -84,12 +89,28 @@
         self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 100)
 
         # 4.  Setup IGMP forwarding toward the controller
+        self.setupIgmpCaptureFlowRules(self.getCookieBlock())
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
         # 5.  Send periodic IGMP queries out toward the ONU and verify its arrival
+        self.testIgmpQueryOut()
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
         # 6.  Send in an IGMP join request for one channel (specific multicast address) and verify that
         #     the controller receives it.
+        self.sendIgmpReport(join=[mcast_groups.ch1])
+
         # 7.  Setup flows for forwarding multicast traffic from the OLT port toward the ONU(s)
+        ch1_cookie = self.getCookieBlock()
+        ch1_group_id = 1
+        self.setupMcastChannel(mcast_groups.ch1, mcast_vlan_id, [onu1.port], ch1_group_id, ch1_cookie)
+
         # 8.  Verify that multicast packets can reach the ONU
+        self.testMcastFlow(mcast_groups.ch1, mcast_vlan_id, [onu1.port], 10)
+
         # 9.  Verify that bidirectional unicast traffic still works across the PON
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
         # 10. Change channel to a new multicast group and verify both multicast receiption as well as
         #     unicast traffic works. This step involves:
         #     - sending a leave message by the ONU to the controller and verify its reception
@@ -99,10 +120,53 @@
         #     - verify that original flow no longer flows
         #     - verify new flow
         #     - verify that unicast still works
+        self.sendIgmpReport(leave=[mcast_groups.ch1], join=[mcast_groups.ch2])
+        self.removeMcastChannel(mcast_groups.ch1, mcast_vlan_id, [onu1.port], ch1_group_id, ch1_cookie)
+        ch2_cookie = self.getCookieBlock()
+        ch2_group_id = 2
+        self.setupMcastChannel(mcast_groups.ch2, mcast_vlan_id, [onu1.port], ch2_group_id, ch2_cookie)
+        self.testMcastFlow(mcast_groups.ch1, mcast_vlan_id, [onu1.port], expect_to_be_blocked=True)
+        self.testMcastFlow(mcast_groups.ch2, mcast_vlan_id, [onu1.port], 10)
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
         # 11. Add a second channel while keeping the existing one. Verify all what needs to be verified
         #     (similar as above)
-        # 12. Add two more multicast channels, and verify everything
-        # 13. Flip a channel for one of the multicast channels, and verify everything
+        self.sendIgmpReport(join=[mcast_groups.ch2, mcast_groups.ch3])
+        ch3_cookie = self.getCookieBlock()
+        ch3_group_id = 3
+        self.setupMcastChannel(mcast_groups.ch3, mcast_vlan_id, [onu1.port], ch3_group_id, ch3_cookie)
+        self.testMcastFlow(mcast_groups.ch1, mcast_vlan_id, [onu1.port], expect_to_be_blocked=True)
+        self.testMcastFlow(mcast_groups.ch2, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch3, mcast_vlan_id, [onu1.port], 10)
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
+        # 12. Add two more multicast channels, and verify everything again
+        self.sendIgmpReport(join=[mcast_groups.ch2, mcast_groups.ch3, mcast_groups.ch4, mcast_groups.ch5])
+        ch4_cookie = self.getCookieBlock()
+        ch4_group_id = 4
+        self.setupMcastChannel(mcast_groups.ch4, mcast_vlan_id, [onu1.port], ch4_group_id, ch4_cookie)
+        ch5_cookie = self.getCookieBlock()
+        ch5_group_id = 5
+        self.setupMcastChannel(mcast_groups.ch5, mcast_vlan_id, [onu1.port], ch5_group_id, ch5_cookie)
+        self.testMcastFlow(mcast_groups.ch1, mcast_vlan_id, [onu1.port], expect_to_be_blocked=True)
+        self.testMcastFlow(mcast_groups.ch2, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch3, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch4, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch5, mcast_vlan_id, [onu1.port], 10)
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
+
+        # 13. Flip a channel for one of the multicast channels, and verify everything (ch3 -> ch1)
+        self.sendIgmpReport(join=[mcast_groups.ch1, mcast_groups.ch2, mcast_groups.ch4, mcast_groups.ch5],
+                            leave=[mcast_groups.ch3])
+        self.removeMcastChannel(mcast_groups.ch3, mcast_vlan_id, [onu1.port], ch3_group_id, ch3_cookie)
+        ## reusing the same group id and cookie
+        self.setupMcastChannel(mcast_groups.ch1, mcast_vlan_id, [onu1.port], ch1_group_id, ch1_cookie)
+        self.testMcastFlow(mcast_groups.ch1, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch2, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch3, mcast_vlan_id, [onu1.port], expect_to_be_blocked=True)
+        self.testMcastFlow(mcast_groups.ch4, mcast_vlan_id, [onu1.port], 10)
+        self.testMcastFlow(mcast_groups.ch5, mcast_vlan_id, [onu1.port], 10)
+        self.testSustainedPacketFlow(onu1.s_vlan_id, c_vlan_id, 10) # to check that unicast still flows
 
         # 14. Tear down the test.
         self.resetOlt()
diff --git a/olt.py b/olt.py
index 13c997a..42ee55d 100644
--- a/olt.py
+++ b/olt.py
@@ -57,15 +57,10 @@
         match.oxm_list.append(ofp.oxm.eth_type(0x800))
         match.oxm_list.append(ofp.oxm.ip_proto(2))
 
-        igmp = IGMPv3(type=IGMP_TYPE_V3_MEMBERSHIP_REPORT, max_resp_code=30, gaddr="224.0.0.1")
-        igmp.grps = [IGMPv3gr(rtype=IGMP_V3_GR_TYPE_INCLUDE, mcaddr="229.10.20.30")]
-        pkt = IGMPv3.fixup( scapy.Ether(src='00:00:00:00:be:ef', dst='01:00:5e:00:00:01') \
-                            / scapy.IP(src='192.168.0.123', dst='224.0.0.22') / igmp )
-        if len(pkt) < 60:
-                pad_len = 60 - len(pkt)
-                pad = scapy.PAD()
-                pad.load = '\x00' * pad_len
-                pkt = pkt/pad
+        mcast_group = "229.10.20.30"
+        ip_src = '192.168.0.123'
+        ip_dst = '224.0.0.22'
+        pkt = self.buildIgmpReport(ip_src, ip_dst, '00:00:00:00:be:ef', join=[mcast_group], pad_to=60)
 
         self.testPacketIn(match, pkt)
 
@@ -77,7 +72,7 @@
         logging.info("Running IGMP query packet out")
 
         igmp = IGMPv3()  # by default this is a query
-        pkt = buildIgmp(igmp)
+        pkt = self.buildIgmp(igmp)
 
         msg = ofp.message.packet_out(
             in_port=ofp.OFPP_CONTROLLER,
@@ -278,9 +273,6 @@
         verify_packets(self, outPkt, [])
 
 
-
-
-
 class TestGroupAdd(OltBaseTest):
 
     def runTest(self):
@@ -743,7 +735,7 @@
         self.assertTrue(len(stats) == 5, \
                         "Wrong number of rules reports; reported %s, expected 5\n\n %s" % (len(stats), stats))
 
-        self.processEapolRule(onu_port, install = False)
+        self.processEapolRule(onu_port, install=False)
         time.sleep(3)
 
         stats = get_flow_stats(self, ofp.match())
@@ -770,11 +762,11 @@
 
         time.sleep(1)
 
-        self.installDoubleTaggingRules(10, 5, cookie=42, onu = onu_port)
+        self.installDoubleTaggingRules(10, 5, cookie=42, onu=onu_port)
 
         time.sleep(1)
 
-        self.installDoubleTaggingRules(10, 30, cookie=50, onu = onu_port2)
+        self.installDoubleTaggingRules(10, 30, cookie=50, onu=onu_port2)
 
         time.sleep(1)
 
@@ -820,13 +812,3 @@
                 self.installDoubleTaggingRules(stag, ctag)
                 time.sleep(5)
                 self.testPacketFlow(stag, ctag)
-
-
-
-
-
-
-
-
-
-
diff --git a/oltbase.py b/oltbase.py
index 6ff4bc4..1118054 100644
--- a/oltbase.py
+++ b/oltbase.py
@@ -5,6 +5,9 @@
 from oltconstants import *
 import copy
 import logging
+from IGMP import IGMPv3, IGMPv3gr, IGMP_TYPE_MEMBERSHIP_QUERY, \
+     IGMP_TYPE_V3_MEMBERSHIP_REPORT, IGMP_V3_GR_TYPE_EXCLUDE, \
+     IGMP_V3_GR_TYPE_INCLUDE
 
 
 class OltBaseTest(base_tests.SimpleDataPlane):
@@ -20,6 +23,7 @@
     def resetOlt(self):
         """Reset the OLT into a clean healthy state"""
         delete_all_flows(self.controller)
+        delete_all_groups(self.controller)
         do_barrier(self.controller)
         verify_no_errors(self.controller)
 
@@ -53,7 +57,7 @@
             verify_packet_in(self, pkt, of_port, ofp.OFPR_ACTION)
             verify_packets(self, pkt, [])
 
-    def processEapolRule(self, in_port, install = True):
+    def processEapolRule(self, in_port, install=True):
         match = ofp.match()
         match.oxm_list.append(ofp.oxm.eth_type(0x888e))
         match.oxm_list.append(ofp.oxm.in_port(in_port))
@@ -133,7 +137,7 @@
 
     def testSustainedPacketFlow(self, s_vlan_id, c_vlan_id, number_of_roundtrips=10, onu=None):
         for i in xrange(number_of_roundtrips):
-            print "loop %d" % i
+            print "pkt # %d" % (i+1,)
             self.testPacketFlow(s_vlan_id, c_vlan_id, onu=onu, verify_blocked_flows=False)
 
     def installDoubleTaggingRules(self, s_vlan_id, c_vlan_id, cookie=42, onu = None):
@@ -243,3 +247,248 @@
         self.controller.message_send(request)
         do_barrier(self.controller)
         verify_no_errors(self.controller)
+
+    def setupIgmpCaptureFlowRules(self, cookie):
+        """Etsablish flow rules that will forward any incoming
+           IGMP packets to controller (from any port, OLT and ONUs)
+        """
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ip_proto(2))
+        request = ofp.message.flow_add(
+            table_id=test_param_get("table", 0),
+            cookie=42,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[
+                        ofp.action.output(
+                            port=ofp.OFPP_CONTROLLER,
+                            max_len=ofp.OFPCML_NO_BUFFER)])],
+            buffer_id=ofp.OFP_NO_BUFFER,
+            priority=2000)
+        logging.info("Inserting flow sending matching packets to controller")
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+
+    def testIgmpQueryOut(self, onu=None):
+        """Send an IGMP Query out to given onu and verify that it arrives"""
+
+        outport = onu_port if onu is None else onu
+
+        igmp = IGMPv3()  # by default this is a query
+        pkt = self.buildIgmp(igmp)
+
+        msg = ofp.message.packet_out(
+            in_port=ofp.OFPP_CONTROLLER,
+            actions=[ofp.action.output(port=outport)],
+            buffer_id=ofp.OFP_NO_BUFFER,
+            data=str(pkt))
+
+        self.controller.message_send(msg)
+
+        rv = self.controller.message_send(msg)
+        self.assertTrue(rv == 0, "Error sending output message")
+        verify_no_errors(self.controller)
+
+        verify_packet(self, pkt, outport)
+
+    def sendIgmpReport(self, join=[], leave=[], onu=None, srcmac=None,
+                     ip_src='1.2.3.4', ip_dst='2440.0.0.22', pad_to=None):
+        """Send an IGMP Join request from given onu to given mcast group"""
+
+        # construct packet
+        in_port = onu_port if onu is None else onu
+        src_mac = '00:00:00:00:be:ef' if srcmac is None else srcmac
+        pkt = str(self.buildIgmpReport(ip_src, ip_dst, src_mac, join, leave, pad_to=pad_to))
+
+        # send packet via in_port
+        self.dataplane.send(in_port, pkt)
+
+        # verify it is received by controller and not by some other port
+        verify_packet_in(self, pkt, in_port, ofp.OFPR_ACTION)
+        verify_packets(self, pkt, [])
+
+    def buildIgmpReport(self, ip_src, ip_dst, srcmac, join=[], leave=[], sources={}, pad_to=None):
+        """ Return an IGMPv4 Membership Report as a scapy Ethernet frame
+            join:    a list of multicast addresses that shall be joined
+            leave:   a list of multicast addresses that shall be left
+            sources: a dict of source lists keyed by the mcast address that needs to be joined
+            or left.
+        """
+
+        assert join or leave, "Either join or leave must be a non-empty list"
+
+        igmp = IGMPv3(type=IGMP_TYPE_V3_MEMBERSHIP_REPORT, max_resp_code=30, gaddr="224.0.0.1")
+
+        for mcast_group in join:
+            srcs = sources.get(mcast_group, [])
+            if len(srcs):
+                gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_INCLUDE, mcaddr=mcast_group)
+                gr.sources = srcs
+            else:
+                gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_EXCLUDE, mcaddr=mcast_group)
+            igmp.grps.append(gr)
+
+        for mcast_group in leave:
+            gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_INCLUDE, mcaddr=mcast_group)
+            igmp.grps.append(gr)
+
+        pkt = IGMPv3.fixup( scapy.Ether(src=srcmac) / scapy.IP() / igmp )
+        pkt = self.padPktTo(pkt, pad_to)
+        return pkt
+
+    def buildIgmp(self, payload, pad_to=None):
+        pkt = IGMPv3.fixup(scapy.Ether() / scapy.IP() / payload)
+        return self.padPktTo(pkt, pad_to)
+
+    def padPktTo(self, pkt, pad_to=None):
+        """If pad_to is provided, it shall be an integer. If pkt is smaller than that number,
+           it will be properly padded (ethernet padding) and returned. Otherwise the original
+           pkt is returned.
+        """
+        if pad_to is None:
+            return pkt
+
+        pad_len = pad_to - len(pkt)
+        if pad_len > 0:
+            pad = scapy.scapy.layers.l2.Padding()
+            pad.load = '\x00' * pad_len
+            pkt = pkt / pad
+        return pkt
+
+    def setupMcastChannel(self, mcast_addr, mcast_vlan_id, port_list, group_id=1, cookie=100):
+        """Setup multicast forwarding for the mcast address using mcast vlan id
+           and given port list.
+        """
+
+        # setup group first with given port list
+
+        buckets = [
+            ofp.common.bucket(
+                watch_port=ofp.OFPP_ANY,
+                watch_group=ofp.OFPG_ANY,
+                actions=[
+                    ofp.action.pop_vlan(),
+                    ofp.action.output(port=port)
+                ])
+            for port in port_list
+        ]
+
+        msg = ofp.message.group_add(
+            group_type=ofp.OFPGT_ALL,
+            group_id=group_id,
+            buckets=buckets
+        )
+
+        self.controller.message_send(msg)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+        # Then setup flow rule pointing to group
+
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.in_port(olt_port))
+        match.oxm_list.append(ofp.oxm.vlan_vid(ofp.OFPVID_PRESENT | mcast_vlan_id))
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ipv4_dst(self.ip2int(mcast_addr)))
+        request = ofp.message.flow_add(
+            table_id=test_param_get("table", 0),
+            cookie=cookie,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[ofp.action.group(group_id=group_id)]),
+            ],
+            buffer_id=ofp.OFP_NO_BUFFER, priority=1000)
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+    def removeMcastChannel(self, mcast_addr, mcast_vlan_id, port_list, group_id, cookie):
+        """Remove mumticast forwarding for given mcast address"""
+
+        # Remove flow first
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.in_port(olt_port))
+        match.oxm_list.append(ofp.oxm.vlan_vid(ofp.OFPVID_PRESENT | mcast_vlan_id))
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ipv4_dst(self.ip2int(mcast_addr)))
+        request = ofp.message.flow_delete(
+            table_id=test_param_get("table", 0),
+            cookie=cookie,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[ofp.action.group(group_id=group_id)]),
+            ],
+            buffer_id=ofp.OFP_NO_BUFFER, priority=1000)
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+        # Then remove the group
+        group_delete = ofp.message.group_delete(group_id=group_id)
+        self.controller.message_send(group_delete)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+    def ip2int(self, ip):
+        """Convert a dot-notated string IP address"""
+        digits = [int(d) for d in ip.split('.')]
+        assert len(digits) == 4
+        val = (
+            (digits[0] & 0xff) << 24 |
+            (digits[1] & 0xff) << 16 |
+            (digits[2] & 0xff) << 8  |
+            (digits[3] & 0xff))
+        return val
+
+    def mcastIp2McastMac(self, ip):
+        """ Convert a dot-notated IPv4 multicast address string into an multicast MAC address"""
+        digits = [int(d) for d in ip.split('.')]
+        return '01:00:5e:%02x:%02x:%02x' % (digits[1] & 0x7f, digits[2] & 0xff, digits[3] & 0xff)
+
+    def testMcastFlow(self, mcast_addr, mcast_vlan_id, ports=None, numpkt=1, ip_src="66.77.88.99",
+                      expect_to_be_blocked=False):
+        """ Send given number of mcast packets using mcast address and vlan_id
+            and check they arrive to given port(s).
+        """
+
+        # construct mcast packet
+        pktlen = 250
+        pktToSend = simple_udp_packet(
+            eth_dst=self.mcastIp2McastMac(mcast_addr),
+            ip_src=ip_src,
+            ip_dst=mcast_addr,
+            pktlen=pktlen+4,
+            dl_vlan_enable=True,
+            vlan_vid=mcast_vlan_id,
+            vlan_pcp=0)
+
+        if device_type == "pmc":
+            pktToReceive = simple_udp_packet(
+                eth_dst=self.mcastIp2McastMac(mcast_addr),
+                ip_src=ip_src,
+                ip_dst=mcast_addr,
+                pktlen=pktlen+4,
+                dl_vlan_enable=True,
+                vlan_vid=0,
+                vlan_pcp=0)
+        else:
+            pktToReceive = simple_udp_packet(
+                eth_dst=self.mcastIp2McastMac(mcast_addr),
+                ip_src=ip_src,
+                ip_dst=mcast_addr,
+                pktlen=pktlen)
+
+        # send mcast packet to olt
+        self.dataplane.send(olt_port, str(pktToSend))
+
+        # test that packet is received on each designated port
+        ports = [onu_port] if ports is None else ports
+        for port in ports:
+            if expect_to_be_blocked:
+                verify_no_packet(self, pktToReceive, port)
+            else:
+                verify_packet(self, pktToReceive, port)
diff --git a/oltconstants.py b/oltconstants.py
index f00fb74..b97ab5a 100644
--- a/oltconstants.py
+++ b/oltconstants.py
@@ -37,15 +37,6 @@
 
     return group_mod
 
-def buildIgmp(payload):
-    pkt = pkt = IGMPv3.fixup(scapy.Ether() / scapy.IP() / payload)
-    if len(pkt) < 60:
-        pad_len = 60 - len(pkt)
-        pad = scapy.scapy.layers.l2.Padding()
-        pad.load = '\x00' * pad_len
-        pkt = pkt / pad
-    return pkt
-
 def double_vlan_udp_packet(pktlen=100,
                            eth_dst='00:01:02:03:04:05',
                            eth_src='00:06:07:08:09:0a',