Enforce to use unique group Id for k8s svc IP, port, proto combination

Change-Id: I6cad3b9ffac86ee0818e6317443c80f5791a9e74
diff --git a/apps/k8s-networking/api/src/main/java/org/onosproject/k8snetworking/api/Constants.java b/apps/k8s-networking/api/src/main/java/org/onosproject/k8snetworking/api/Constants.java
index 076d5d6..50170fa 100644
--- a/apps/k8s-networking/api/src/main/java/org/onosproject/k8snetworking/api/Constants.java
+++ b/apps/k8s-networking/api/src/main/java/org/onosproject/k8snetworking/api/Constants.java
@@ -41,6 +41,9 @@
     public static final String SHIFTED_IP_CIDR = "172.10.0.0/16";
     public static final String SHIFTED_IP_PREFIX = "172.10";
 
+    public static final String SRC = "src";
+    public static final String DST = "dst";
+
     // TODO: need to inject service IP CIDR through REST
     public static final String SERVICE_IP_CIDR = "10.96.0.0/24";
 
diff --git a/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/impl/K8sServiceHandler.java b/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/impl/K8sServiceHandler.java
index 555a144..f2e56e0 100644
--- a/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/impl/K8sServiceHandler.java
+++ b/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/impl/K8sServiceHandler.java
@@ -17,12 +17,12 @@
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
-import com.google.common.collect.Sets;
 import io.fabric8.kubernetes.api.model.EndpointAddress;
 import io.fabric8.kubernetes.api.model.EndpointPort;
 import io.fabric8.kubernetes.api.model.EndpointSubset;
 import io.fabric8.kubernetes.api.model.Endpoints;
 import io.fabric8.kubernetes.api.model.Service;
+import io.fabric8.kubernetes.api.model.ServicePort;
 import org.onlab.packet.Ethernet;
 import org.onlab.packet.IPv4;
 import org.onlab.packet.IpAddress;
@@ -40,7 +40,6 @@
 import org.onosproject.k8snetworking.api.K8sFlowRuleService;
 import org.onosproject.k8snetworking.api.K8sGroupRuleService;
 import org.onosproject.k8snetworking.api.K8sNetworkService;
-import org.onosproject.k8snetworking.api.K8sPodService;
 import org.onosproject.k8snetworking.api.K8sServiceEvent;
 import org.onosproject.k8snetworking.api.K8sServiceListener;
 import org.onosproject.k8snetworking.api.K8sServiceService;
@@ -60,7 +59,10 @@
 import org.onosproject.net.flow.criteria.ExtensionSelector;
 import org.onosproject.net.flow.instructions.ExtensionTreatment;
 import org.onosproject.net.group.GroupBucket;
+import org.onosproject.store.serializers.KryoNamespaces;
 import org.onosproject.store.service.AtomicCounter;
+import org.onosproject.store.service.ConsistentMap;
+import org.onosproject.store.service.Serializer;
 import org.onosproject.store.service.StorageService;
 import org.osgi.service.component.ComponentContext;
 import org.osgi.service.component.annotations.Activate;
@@ -81,6 +83,7 @@
 
 import static java.util.concurrent.Executors.newSingleThreadExecutor;
 import static org.onlab.util.Tools.groupedThreads;
+import static org.onosproject.k8snetworking.api.Constants.DST;
 import static org.onosproject.k8snetworking.api.Constants.JUMP_TABLE;
 import static org.onosproject.k8snetworking.api.Constants.K8S_NETWORKING_APP_ID;
 import static org.onosproject.k8snetworking.api.Constants.NAT_STATEFUL;
@@ -94,6 +97,7 @@
 import static org.onosproject.k8snetworking.api.Constants.SERVICE_TABLE;
 import static org.onosproject.k8snetworking.api.Constants.SHIFTED_IP_CIDR;
 import static org.onosproject.k8snetworking.api.Constants.SHIFTED_IP_PREFIX;
+import static org.onosproject.k8snetworking.api.Constants.SRC;
 import static org.onosproject.k8snetworking.impl.OsgiPropertyConstants.SERVICE_IP_NAT_MODE;
 import static org.onosproject.k8snetworking.impl.OsgiPropertyConstants.SERVICE_IP_NAT_MODE_DEFAULT;
 import static org.onosproject.k8snetworking.util.K8sNetworkingUtil.nodeIpGatewayIpMap;
@@ -125,6 +129,7 @@
     private static final String NONE = "None";
     private static final String CLUSTER_IP = "ClusterIP";
     private static final String TCP = "TCP";
+    private static final String UDP = "UDP";
 
     private static final String GROUP_ID_COUNTER_NAME = "group-id-counter";
 
@@ -167,9 +172,6 @@
     @Reference(cardinality = ReferenceCardinality.MANDATORY)
     protected K8sServiceService k8sServiceService;
 
-    @Reference(cardinality = ReferenceCardinality.MANDATORY)
-    protected K8sPodService k8sPodService;
-
     /** Service IP address translation mode. */
     private String serviceIpNatMode = SERVICE_IP_NAT_MODE_DEFAULT;
 
@@ -182,6 +184,9 @@
 
     private AtomicCounter groupIdCounter;
 
+    // service IP ports has following format IP_PORT_PROTO
+    private ConsistentMap<String, Integer> servicePortGroupIdMap;
+
     private ApplicationId appId;
     private NodeId localNodeId;
 
@@ -196,6 +201,12 @@
 
         groupIdCounter = storageService.getAtomicCounter(GROUP_ID_COUNTER_NAME);
 
+        servicePortGroupIdMap = storageService.<String, Integer>consistentMapBuilder()
+                .withName("k8s-service-ip-port-set")
+                .withSerializer(Serializer.using(KryoNamespaces.API))
+                .withApplicationId(appId)
+                .build();
+
         log.info("Started");
     }
 
@@ -287,83 +298,131 @@
                 install);
     }
 
-    private void setStatelessGroupFlowRules(DeviceId deviceId,
-                                            Service service, boolean install) {
-        int groupId = (int) groupIdCounter.incrementAndGet();
+    private String servicePortStr(String ip, int port, String protocol) {
+        return ip + "_" + port + "_" + protocol;
+    }
 
-        List<GroupBucket> buckets = Lists.newArrayList();
+    /**
+     * Obtains the service port to endpoint address paired map.
+     *
+     * @param service   kubernetes service
+     * @return a map where key is kubernetes service port, and value is the
+     * endpoint addresses that are associated with the service port
+     */
+    private Map<ServicePort, Set<String>> getSportEpAddressMap(Service service) {
+
+        Map<ServicePort, Set<String>> map = Maps.newConcurrentMap();
 
         String serviceName = service.getMetadata().getName();
-
         List<Endpoints> endpointses = k8sEndpointsService.endpointses()
                 .stream()
                 .filter(ep -> serviceName.equals(ep.getMetadata().getName()))
                 .collect(Collectors.toList());
 
-        Map<String, String> nodeIpGatewayIpMap =
-                nodeIpGatewayIpMap(k8sNodeService, k8sNetworkService);
+        service.getSpec().getPorts().stream()
+                .filter(Objects::nonNull)
+                .filter(sp -> sp.getTargetPort() != null)
+                .filter(sp -> sp.getTargetPort().getIntVal() != null)
+                .forEach(sp -> {
+            int targetPort = sp.getTargetPort().getIntVal();
+            String targetProtocol = sp.getProtocol();
 
-        Map<String, Set<Integer>> podIpPorts = Maps.newConcurrentMap();
-
-        for (Endpoints endpoints : endpointses) {
-            for (EndpointSubset endpointSubset : endpoints.getSubsets()) {
-                List<EndpointPort> ports = endpointSubset.getPorts()
-                        .stream()
-                        .filter(p -> p.getProtocol().equals(TCP))
-                        .collect(Collectors.toList());
-
-                for (EndpointAddress address : endpointSubset.getAddresses()) {
-                    String podIp = nodeIpGatewayIpMap.containsKey(address.getIp()) ?
-                            nodeIpGatewayIpMap.get(address.getIp()) : address.getIp();
-
-                    ports.forEach(p -> {
-                        ExtensionTreatment resubmitTreatment = buildResubmitExtension(
-                                deviceService.getDevice(deviceId), ROUTING_TABLE);
-                        TrafficTreatment treatment = DefaultTrafficTreatment.builder()
-                                .setIpDst(IpAddress.valueOf(podIp))
-                                .setTcpDst(TpPort.tpPort(p.getPort()))
-                                .extension(resubmitTreatment, deviceId)
-                                .build();
-                        buckets.add(buildGroupBucket(treatment, SELECT, (short) -1));
-
-                        Set<Integer> existPorts = podIpPorts.get(podIp);
-                        if (existPorts == null || existPorts.isEmpty()) {
-                            existPorts = Sets.newConcurrentHashSet();
+            for (Endpoints endpoints : endpointses) {
+                for (EndpointSubset endpointSubset : endpoints.getSubsets()) {
+                    for (EndpointPort endpointPort : endpointSubset.getPorts()) {
+                        if (targetProtocol.equals(endpointPort.getProtocol()) &&
+                                targetPort == endpointPort.getPort()) {
+                            Set<String> addresses = endpointSubset.getAddresses()
+                                    .stream().map(EndpointAddress::getIp)
+                                    .collect(Collectors.toSet());
+                            map.put(sp, addresses);
                         }
-                        existPorts.add(p.getPort());
-                        podIpPorts.put(podIp, existPorts);
-                    });
+                    }
                 }
             }
-        }
+        });
 
-        if (!buckets.isEmpty()) {
-            k8sGroupRuleService.setRule(appId, deviceId, groupId, SELECT, buckets, install);
+        return map;
+    }
+
+    private void setStatelessGroupFlowRules(DeviceId deviceId, Service service,
+                                            boolean install) {
+        Map<ServicePort, Set<String>> spEpasMap = getSportEpAddressMap(service);
+        Map<String, String> nodeIpGatewayIpMap =
+                nodeIpGatewayIpMap(k8sNodeService, k8sNetworkService);
+        Map<ServicePort, List<GroupBucket>> spGrpBkts = Maps.newConcurrentMap();
+
+        spEpasMap.forEach((sp, epas) -> {
+            List<GroupBucket> bkts = Lists.newArrayList();
+            epas.forEach(epa -> {
+                String podIp = nodeIpGatewayIpMap.getOrDefault(epa, epa);
+                ExtensionTreatment resubmitTreatment = buildResubmitExtension(
+                        deviceService.getDevice(deviceId), ROUTING_TABLE);
+                TrafficTreatment.Builder tBuilder = DefaultTrafficTreatment.builder()
+                        .setIpDst(IpAddress.valueOf(podIp))
+                        .extension(resubmitTreatment, deviceId);
+
+                if (TCP.equals(sp.getProtocol())) {
+                    tBuilder.setTcpDst(TpPort.tpPort(sp.getTargetPort().getIntVal()));
+                } else if (UDP.equals(sp.getProtocol())) {
+                    tBuilder.setUdpDst(TpPort.tpPort(sp.getTargetPort().getIntVal()));
+                }
+
+                bkts.add(buildGroupBucket(tBuilder.build(), SELECT, (short) -1));
+            });
+            spGrpBkts.put(sp, bkts);
+        });
+
+        String serviceIp = service.getSpec().getClusterIP();
+        spGrpBkts.forEach((sp, bkts) -> {
+            String svcStr = servicePortStr(serviceIp, sp.getPort(), sp.getProtocol());
+            int groupId;
+
+            if (servicePortGroupIdMap.asJavaMap().containsKey(svcStr)) {
+                groupId = servicePortGroupIdMap.asJavaMap().get(svcStr);
+            } else {
+                groupId = (int) groupIdCounter.incrementAndGet();
+                servicePortGroupIdMap.put(svcStr, groupId);
+            }
+
+            // add group table rules
+            k8sGroupRuleService.setRule(appId, deviceId, groupId,
+                    SELECT, bkts, install);
+
+            // add flow rules for shifting IP domain
             setShiftDomainRules(deviceId, SERVICE_TABLE, groupId,
-                    PRIORITY_NAT_RULE, service, install);
+                    PRIORITY_NAT_RULE, serviceIp, sp.getPort(),
+                    sp.getProtocol(), install);
+        });
 
-            podIpPorts.forEach((k, v) ->
-                v.forEach(p -> setUnshiftDomainRules(deviceId, POD_TABLE,
-                        PRIORITY_NAT_RULE, service, k, p, install)));
-        }
+        spEpasMap.forEach((sp, epas) ->
+            // add flow rules for unshifting IP domain
+            epas.forEach(epa -> {
+                String podIp = nodeIpGatewayIpMap.getOrDefault(epa, epa);
+                setUnshiftDomainRules(deviceId, POD_TABLE,
+                PRIORITY_NAT_RULE, serviceIp, sp.getPort(), sp.getProtocol(),
+                        podIp, sp.getTargetPort().getIntVal(), install);
+            }
+        ));
     }
 
     private void setShiftDomainRules(DeviceId deviceId, int installTable,
-                                     int groupId, int priority,
-                                     Service service, boolean install) {
-        String serviceIp = service.getSpec().getClusterIP();
-        // TODO: multi-ports case should be addressed
-        Integer servicePort = service.getSpec().getPorts().get(0).getPort();
-
-        TrafficSelector selector = DefaultTrafficSelector.builder()
+                                     int groupId, int priority, String serviceIp,
+                                     int servicePort, String protocol, boolean install) {
+        TrafficSelector.Builder sBuilder = DefaultTrafficSelector.builder()
                 .matchEthType(Ethernet.TYPE_IPV4)
-                .matchIPProtocol(IPv4.PROTOCOL_TCP)
-                .matchIPDst(IpPrefix.valueOf(IpAddress.valueOf(serviceIp), 32))
-                .matchTcpDst(TpPort.tpPort(servicePort))
-                .build();
+                .matchIPDst(IpPrefix.valueOf(IpAddress.valueOf(serviceIp), HOST_CIDR_NUM));
+
+        if (TCP.equals(protocol)) {
+            sBuilder.matchIPProtocol(IPv4.PROTOCOL_TCP)
+                    .matchTcpDst(TpPort.tpPort(servicePort));
+        } else if (UDP.equals(protocol)) {
+            sBuilder.matchIPProtocol(IPv4.PROTOCOL_UDP)
+                    .matchUdpDst(TpPort.tpPort(servicePort));
+        }
 
         ExtensionTreatment loadTreatment = buildLoadExtension(
-                deviceService.getDevice(deviceId), "src", SHIFTED_IP_PREFIX);
+                deviceService.getDevice(deviceId), SRC, SHIFTED_IP_PREFIX);
 
         TrafficTreatment treatment = DefaultTrafficTreatment.builder()
                 .extension(loadTreatment, deviceId)
@@ -373,7 +432,7 @@
         k8sFlowRuleService.setRule(
                 appId,
                 deviceId,
-                selector,
+                sBuilder.build(),
                 treatment,
                 priority,
                 installTable,
@@ -381,34 +440,43 @@
     }
 
     private void setUnshiftDomainRules(DeviceId deviceId, int installTable,
-                                       int priority, Service service, String podIp,
-                                       int podPort, boolean install) {
-        String serviceIp = service.getSpec().getClusterIP();
-        // TODO: multi-ports case should be addressed
-        Integer servicePort = service.getSpec().getPorts().get(0).getPort();
-
-        TrafficSelector selector = DefaultTrafficSelector.builder()
+                                       int priority, String serviceIp,
+                                       int servicePort, String protocol,
+                                       String podIp, int podPort, boolean install) {
+        TrafficSelector.Builder sBuilder = DefaultTrafficSelector.builder()
                 .matchEthType(Ethernet.TYPE_IPV4)
-                .matchIPProtocol(IPv4.PROTOCOL_TCP)
-                .matchIPSrc(IpPrefix.valueOf(IpAddress.valueOf(podIp), 32))
-                .matchTcpSrc(TpPort.tpPort(podPort))
-                .build();
+                .matchIPSrc(IpPrefix.valueOf(IpAddress.valueOf(podIp), HOST_CIDR_NUM));
+
+        if (TCP.equals(protocol)) {
+            sBuilder.matchIPProtocol(IPv4.PROTOCOL_TCP)
+                    .matchTcpSrc(TpPort.tpPort(podPort));
+        } else if (UDP.equals(protocol)) {
+            sBuilder.matchIPProtocol(IPv4.PROTOCOL_UDP)
+                    .matchUdpSrc(TpPort.tpPort(podPort));
+        }
+
+        String podIpPrefix = podIp.split("\\.")[0] +
+                                            "." + podIp.split("\\.")[1];
 
         ExtensionTreatment loadTreatment = buildLoadExtension(
-                deviceService.getDevice(deviceId), "dst", "10.10");
+                deviceService.getDevice(deviceId), DST, podIpPrefix);
 
-        TrafficTreatment treatment = DefaultTrafficTreatment.builder()
+        TrafficTreatment.Builder tBuilder = DefaultTrafficTreatment.builder()
                 .extension(loadTreatment, deviceId)
                 .setIpSrc(IpAddress.valueOf(serviceIp))
-                .setTcpSrc(TpPort.tpPort(servicePort))
-                .transition(ROUTING_TABLE)
-                .build();
+                .transition(ROUTING_TABLE);
+
+        if (TCP.equals(protocol)) {
+            tBuilder.setTcpSrc(TpPort.tpPort(servicePort));
+        } else if (UDP.equals(protocol)) {
+            tBuilder.setUdpSrc(TpPort.tpPort(servicePort));
+        }
 
         k8sFlowRuleService.setRule(
                 appId,
                 deviceId,
-                selector,
-                treatment,
+                sBuilder.build(),
+                tBuilder.build(),
                 priority,
                 installTable,
                 install);
@@ -597,17 +665,17 @@
                 return;
             }
 
-            long ctTrackNew = computeCtStateFlag(true, true, false);
-            long ctMaskTrackNew = computeCtMaskFlag(true, true, false);
+            if (NAT_STATEFUL.equals(serviceIpNatMode)) {
+                long ctTrackNew = computeCtStateFlag(true, true, false);
+                long ctMaskTrackNew = computeCtMaskFlag(true, true, false);
 
-            k8sNodeService.completeNodes().forEach(n -> {
-                if (NAT_STATEFUL.equals(serviceIpNatMode)) {
-                    setStatefulGroupFlowRules(n.intgBridge(), ctTrackNew,
-                            ctMaskTrackNew, service, true);
-                } else if (NAT_STATELESS.equals(serviceIpNatMode)) {
-                    setStatelessGroupFlowRules(n.intgBridge(), service, true);
-                }
-            });
+                k8sNodeService.completeNodes().forEach(n ->
+                        setStatefulGroupFlowRules(n.intgBridge(), ctTrackNew,
+                                ctMaskTrackNew, service, true));
+            } else if (NAT_STATELESS.equals(serviceIpNatMode)) {
+                k8sNodeService.completeNodes().forEach(n ->
+                        setStatelessGroupFlowRules(n.intgBridge(), service, true));
+            }
         }
 
         private void processServiceRemoval(Service service) {
@@ -615,17 +683,17 @@
                 return;
             }
 
-            long ctTrackNew = computeCtStateFlag(true, true, false);
-            long ctMaskTrackNew = computeCtMaskFlag(true, true, false);
+            if (NAT_STATEFUL.equals(serviceIpNatMode)) {
+                long ctTrackNew = computeCtStateFlag(true, true, false);
+                long ctMaskTrackNew = computeCtMaskFlag(true, true, false);
 
-            k8sNodeService.completeNodes().forEach(n -> {
-                if (NAT_STATEFUL.equals(serviceIpNatMode)) {
-                    setStatefulGroupFlowRules(n.intgBridge(), ctTrackNew,
-                            ctMaskTrackNew, service, false);
-                } else if (NAT_STATELESS.equals(serviceIpNatMode)) {
-                    setStatelessGroupFlowRules(n.intgBridge(), service, false);
-                }
-            });
+                k8sNodeService.completeNodes().forEach(n ->
+                        setStatefulGroupFlowRules(n.intgBridge(), ctTrackNew,
+                                ctMaskTrackNew, service, false));
+            } else if (NAT_STATELESS.equals(serviceIpNatMode)) {
+                k8sNodeService.completeNodes().forEach(n ->
+                        setStatelessGroupFlowRules(n.intgBridge(), service, false));
+            }
         }
     }
 
@@ -677,7 +745,6 @@
             } else {
                 log.warn("Service IP NAT mode was not configured!");
             }
-
         }
     }
 }
diff --git a/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/util/RulePopulatorUtil.java b/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/util/RulePopulatorUtil.java
index 7605a5e..25a44ea 100644
--- a/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/util/RulePopulatorUtil.java
+++ b/apps/k8s-networking/app/src/main/java/org/onosproject/k8snetworking/util/RulePopulatorUtil.java
@@ -39,6 +39,8 @@
 import java.util.ArrayList;
 import java.util.List;
 
+import static org.onosproject.k8snetworking.api.Constants.DST;
+import static org.onosproject.k8snetworking.api.Constants.SRC;
 import static org.onosproject.net.flow.instructions.ExtensionTreatmentType.ExtensionTreatmentTypes.NICIRA_LOAD;
 import static org.onosproject.net.flow.instructions.ExtensionTreatmentType.ExtensionTreatmentTypes.NICIRA_RESUBMIT_TABLE;
 import static org.onosproject.net.flow.instructions.ExtensionTreatmentType.ExtensionTreatmentTypes.NICIRA_SET_TUNNEL_DST;
@@ -93,9 +95,6 @@
     private static final int SRC_IP = 0x00000e04;
     private static final int DST_IP = 0x00001004;
 
-    private static final String SRC = "src";
-    private static final String DST = "dst";
-
     private static final int OFF_SET_BIT = 16;
     private static final int REMAINDER_BIT = 16;