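Complete complex scenario 1: IGMP and multicast test helpers

Add IGMP capture, query and report helpers plus multicast channel
setup/teardown and traffic helpers to OltBaseTest, and also delete all
groups when resetting the OLT.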

Change-Id: I1c14fd5ffe017523bf32f12efc150711880ed71d
diff --git a/oltbase.py b/oltbase.py
index 6ff4bc4..1118054 100644
--- a/oltbase.py
+++ b/oltbase.py
@@ -5,6 +5,9 @@
 from oltconstants import *
 import copy
 import logging
+from IGMP import IGMPv3, IGMPv3gr, IGMP_TYPE_MEMBERSHIP_QUERY, \
+     IGMP_TYPE_V3_MEMBERSHIP_REPORT, IGMP_V3_GR_TYPE_EXCLUDE, \
+     IGMP_V3_GR_TYPE_INCLUDE
 
 
 class OltBaseTest(base_tests.SimpleDataPlane):
@@ -20,6 +23,7 @@
     def resetOlt(self):
         """Reset the OLT into a clean healthy state"""
         delete_all_flows(self.controller)
+        delete_all_groups(self.controller)  # also drop groups, e.g. those added by setupMcastChannel
         do_barrier(self.controller)
         verify_no_errors(self.controller)
 
@@ -53,7 +57,7 @@
             verify_packet_in(self, pkt, of_port, ofp.OFPR_ACTION)
             verify_packets(self, pkt, [])
 
-    def processEapolRule(self, in_port, install = True):
+    def processEapolRule(self, in_port, install=True):
         match = ofp.match()
         match.oxm_list.append(ofp.oxm.eth_type(0x888e))
         match.oxm_list.append(ofp.oxm.in_port(in_port))
@@ -133,7 +137,7 @@
 
     def testSustainedPacketFlow(self, s_vlan_id, c_vlan_id, number_of_roundtrips=10, onu=None):
         for i in xrange(number_of_roundtrips):
-            print "loop %d" % i
+            print "pkt # %d" % (i+1,)  # 1-based round-trip counter
             self.testPacketFlow(s_vlan_id, c_vlan_id, onu=onu, verify_blocked_flows=False)
 
     def installDoubleTaggingRules(self, s_vlan_id, c_vlan_id, cookie=42, onu = None):
@@ -243,3 +247,301 @@
         self.controller.message_send(request)
         do_barrier(self.controller)
         verify_no_errors(self.controller)
+
+    def setupIgmpCaptureFlowRules(self, cookie):
+        """Establish a flow rule that forwards any incoming IGMP packet
+           (IPv4 protocol 2, from any port, OLT and ONUs) to the controller.
+
+           cookie: OpenFlow cookie assigned to the installed flow.
+        """
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ip_proto(2))  # IP protocol 2 == IGMP
+        request = ofp.message.flow_add(
+            table_id=test_param_get("table", 0),
+            cookie=cookie,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[
+                        ofp.action.output(
+                            port=ofp.OFPP_CONTROLLER,
+                            max_len=ofp.OFPCML_NO_BUFFER)])],
+            buffer_id=ofp.OFP_NO_BUFFER,
+            priority=2000)
+        logging.info("Inserting flow sending matching packets to controller")
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+    def testIgmpQueryOut(self, onu=None):
+        """Send an IGMP Query out to given onu and verify that it arrives.
+
+           onu: output port; defaults to onu_port.
+        """
+
+        outport = onu_port if onu is None else onu
+
+        igmp = IGMPv3()  # by default this is a query
+        pkt = self.buildIgmp(igmp)
+
+        msg = ofp.message.packet_out(
+            in_port=ofp.OFPP_CONTROLLER,
+            actions=[ofp.action.output(port=outport)],
+            buffer_id=ofp.OFP_NO_BUFFER,
+            data=str(pkt))
+
+        # Send the packet_out exactly once; sending it twice would put a
+        # duplicate query on the wire.
+        rv = self.controller.message_send(msg)
+        self.assertTrue(rv == 0, "Error sending output message")
+        verify_no_errors(self.controller)
+
+        verify_packet(self, pkt, outport)
+
+    def sendIgmpReport(self, join=None, leave=None, onu=None, srcmac=None,
+                       ip_src='1.2.3.4', ip_dst='224.0.0.22', pad_to=None):
+        """Send an IGMPv3 Membership Report (joins and/or leaves) from given
+           onu and verify it reaches the controller and no other port.
+
+           join:   list of multicast groups to join
+           leave:  list of multicast groups to leave
+           onu:    ingress port; defaults to onu_port
+           srcmac: Ethernet source address; defaults to 00:00:00:00:be:ef
+           ip_src: IPv4 source address of the report
+           ip_dst: IPv4 destination address of the report (224.0.0.22 is
+                   the IGMPv3 all-IGMPv3-capable-routers address)
+           pad_to: optional minimum frame length (see padPktTo)
+        """
+
+        # construct packet
+        in_port = onu_port if onu is None else onu
+        src_mac = '00:00:00:00:be:ef' if srcmac is None else srcmac
+        pkt = str(self.buildIgmpReport(ip_src, ip_dst, src_mac, join, leave, pad_to=pad_to))
+
+        # send packet via in_port
+        self.dataplane.send(in_port, pkt)
+
+        # verify it is received by controller and not by some other port
+        verify_packet_in(self, pkt, in_port, ofp.OFPR_ACTION)
+        verify_packets(self, pkt, [])
+
+    def buildIgmpReport(self, ip_src, ip_dst, srcmac, join=None, leave=None,
+                        sources=None, pad_to=None):
+        """Return an IGMPv3 Membership Report as a scapy Ethernet frame.
+
+           ip_src:  IPv4 source address
+           ip_dst:  IPv4 destination address (NOTE(review): IGMPv3.fixup may
+                    rewrite IP header fields; confirm it keeps this value)
+           srcmac:  Ethernet source address
+           join:    a list of multicast addresses that shall be joined
+           leave:   a list of multicast addresses that shall be left
+           sources: a dict of source lists keyed by mcast address; only
+                    consulted for the groups in join
+           pad_to:  optional minimum frame length (see padPktTo)
+        """
+        join = join or []
+        leave = leave or []
+        sources = sources or {}
+
+        assert join or leave, "Either join or leave must be a non-empty list"
+
+        igmp = IGMPv3(type=IGMP_TYPE_V3_MEMBERSHIP_REPORT, max_resp_code=30, gaddr="224.0.0.1")
+
+        for mcast_group in join:
+            srcs = sources.get(mcast_group, [])
+            if srcs:
+                # source-specific join: INCLUDE only the listed sources
+                gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_INCLUDE, mcaddr=mcast_group)
+                gr.sources = srcs
+            else:
+                # any-source join: EXCLUDE record without listed sources
+                gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_EXCLUDE, mcaddr=mcast_group)
+            igmp.grps.append(gr)
+
+        for mcast_group in leave:
+            # leave: INCLUDE record without listed sources
+            gr = IGMPv3gr(rtype=IGMP_V3_GR_TYPE_INCLUDE, mcaddr=mcast_group)
+            igmp.grps.append(gr)
+
+        pkt = IGMPv3.fixup(scapy.Ether(src=srcmac) / scapy.IP(src=ip_src, dst=ip_dst) / igmp)
+        return self.padPktTo(pkt, pad_to)
+
+    def buildIgmp(self, payload, pad_to=None):
+        """Wrap an IGMP payload in Ether/IP headers, fix them up with
+           IGMPv3.fixup and optionally pad the frame (see padPktTo).
+        """
+        pkt = IGMPv3.fixup(scapy.Ether() / scapy.IP() / payload)
+        return self.padPktTo(pkt, pad_to)
+
+    def padPktTo(self, pkt, pad_to=None):
+        """Append zero-byte Ethernet padding so pkt is at least pad_to bytes.
+
+           If pad_to is None, or pkt is already at least pad_to bytes long,
+           the original pkt is returned unchanged.
+        """
+        if pad_to is None:
+            return pkt
+
+        pad_len = pad_to - len(pkt)
+        if pad_len > 0:
+            pad = scapy.scapy.layers.l2.Padding()
+            pad.load = '\x00' * pad_len
+            pkt = pkt / pad
+        return pkt
+
+    def setupMcastChannel(self, mcast_addr, mcast_vlan_id, port_list, group_id=1, cookie=100):
+        """Setup multicast forwarding for the mcast address using mcast vlan id
+           and given port list.
+
+           Installs an ALL-type group with one bucket per port in port_list
+           (pop VLAN, output to port), then a flow matching IPv4 traffic from
+           olt_port on mcast_vlan_id destined to mcast_addr that applies the
+           group.
+        """
+
+        # setup group first with given port list
+
+        buckets = [
+            ofp.common.bucket(
+                watch_port=ofp.OFPP_ANY,
+                watch_group=ofp.OFPG_ANY,
+                actions=[
+                    ofp.action.pop_vlan(),
+                    ofp.action.output(port=port)
+                ])
+            for port in port_list
+        ]
+
+        msg = ofp.message.group_add(
+            group_type=ofp.OFPGT_ALL,
+            group_id=group_id,
+            buckets=buckets
+        )
+
+        self.controller.message_send(msg)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+        # Then setup flow rule pointing to group
+
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.in_port(olt_port))
+        match.oxm_list.append(ofp.oxm.vlan_vid(ofp.OFPVID_PRESENT | mcast_vlan_id))
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ipv4_dst(self.ip2int(mcast_addr)))
+        request = ofp.message.flow_add(
+            table_id=test_param_get("table", 0),
+            cookie=cookie,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[ofp.action.group(group_id=group_id)]),
+            ],
+            buffer_id=ofp.OFP_NO_BUFFER, priority=1000)
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+    def removeMcastChannel(self, mcast_addr, mcast_vlan_id, port_list, group_id, cookie):
+        """Remove multicast forwarding for given mcast address.
+
+           Deletes the flow matching mcast_addr/mcast_vlan_id from olt_port,
+           then the group group_id. port_list is currently unused.
+        """
+
+        # Remove flow first
+        match = ofp.match()
+        match.oxm_list.append(ofp.oxm.in_port(olt_port))
+        match.oxm_list.append(ofp.oxm.vlan_vid(ofp.OFPVID_PRESENT | mcast_vlan_id))
+        match.oxm_list.append(ofp.oxm.eth_type(0x800))
+        match.oxm_list.append(ofp.oxm.ipv4_dst(self.ip2int(mcast_addr)))
+        request = ofp.message.flow_delete(
+            table_id=test_param_get("table", 0),
+            cookie=cookie,
+            match=match,
+            instructions=[
+                ofp.instruction.apply_actions(
+                    actions=[ofp.action.group(group_id=group_id)]),
+            ],
+            buffer_id=ofp.OFP_NO_BUFFER, priority=1000,
+            # The flow forwards via a group, not an output port, so the
+            # delete must not be filtered by output port or group.
+            out_port=ofp.OFPP_ANY, out_group=ofp.OFPG_ANY)
+        self.controller.message_send(request)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+        # Then remove the group
+        group_delete = ofp.message.group_delete(group_id=group_id)
+        self.controller.message_send(group_delete)
+        do_barrier(self.controller)
+        verify_no_errors(self.controller)
+
+    def ip2int(self, ip):
+        """Convert a dot-notated IPv4 address string into a 32-bit integer."""
+        digits = [int(d) for d in ip.split('.')]
+        assert len(digits) == 4
+        val = (
+            (digits[0] & 0xff) << 24 |
+            (digits[1] & 0xff) << 16 |
+            (digits[2] & 0xff) << 8  |
+            (digits[3] & 0xff))
+        return val
+
+    def mcastIp2McastMac(self, ip):
+        """Convert a dot-notated IPv4 multicast address string into a multicast MAC address.
+
+           The low 23 bits of the IP address fill the low 23 bits of 01:00:5e:00:00:00.
+        """
+        digits = [int(d) for d in ip.split('.')]
+        return '01:00:5e:%02x:%02x:%02x' % (digits[1] & 0x7f, digits[2] & 0xff, digits[3] & 0xff)
+
+    def testMcastFlow(self, mcast_addr, mcast_vlan_id, ports=None, numpkt=1, ip_src="66.77.88.99",
+                      expect_to_be_blocked=False):
+        """Send numpkt mcast packets using mcast address and vlan_id and check
+           each one arrives at every given port (or, when expect_to_be_blocked
+           is set, does not arrive). ports defaults to [onu_port].
+        """
+
+        # construct mcast packet; the tagged frame is 4 bytes longer than
+        # pktlen to account for the VLAN tag
+        pktlen = 250
+        pktToSend = simple_udp_packet(
+            eth_dst=self.mcastIp2McastMac(mcast_addr),
+            ip_src=ip_src,
+            ip_dst=mcast_addr,
+            pktlen=pktlen+4,
+            dl_vlan_enable=True,
+            vlan_vid=mcast_vlan_id,
+            vlan_pcp=0)
+
+        if device_type == "pmc":
+            # on pmc the received frame is expected to keep a VLAN tag
+            # with vid 0 instead of arriving untagged
+            pktToReceive = simple_udp_packet(
+                eth_dst=self.mcastIp2McastMac(mcast_addr),
+                ip_src=ip_src,
+                ip_dst=mcast_addr,
+                pktlen=pktlen+4,
+                dl_vlan_enable=True,
+                vlan_vid=0,
+                vlan_pcp=0)
+        else:
+            pktToReceive = simple_udp_packet(
+                eth_dst=self.mcastIp2McastMac(mcast_addr),
+                ip_src=ip_src,
+                ip_dst=mcast_addr,
+                pktlen=pktlen)
+
+        ports = [onu_port] if ports is None else ports
+        for _ in xrange(numpkt):
+            # send mcast packet to olt
+            self.dataplane.send(olt_port, str(pktToSend))
+
+            # test that packet is received (or blocked) on each designated port
+            for port in ports:
+                if expect_to_be_blocked:
+                    verify_no_packet(self, pktToReceive, port)
+                else:
+                    verify_packet(self, pktToReceive, port)