[CORD-1751] Preventing attacks on DHCP-Relay

Change-Id: I46f7ba2490994e71c9f7d881cbe44785720f1e37
diff --git a/apps/dhcprelay/src/main/java/org/onosproject/dhcprelay/Dhcp4HandlerImpl.java b/apps/dhcprelay/src/main/java/org/onosproject/dhcprelay/Dhcp4HandlerImpl.java
index a65ebc7..d15a34a 100644
--- a/apps/dhcprelay/src/main/java/org/onosproject/dhcprelay/Dhcp4HandlerImpl.java
+++ b/apps/dhcprelay/src/main/java/org/onosproject/dhcprelay/Dhcp4HandlerImpl.java
@@ -17,7 +17,11 @@
 
 package org.onosproject.dhcprelay;
 
+import com.google.common.collect.HashMultimap;
+import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Multimap;
+import com.google.common.collect.Multimaps;
 import com.google.common.collect.Sets;
 import org.apache.felix.scr.annotations.Activate;
 import org.apache.felix.scr.annotations.Component;
@@ -33,16 +36,32 @@
 import org.onlab.packet.Ip4Address;
 import org.onlab.packet.IpAddress;
 import org.onlab.packet.MacAddress;
+import org.onlab.packet.TpPort;
 import org.onlab.packet.UDP;
 import org.onlab.packet.VlanId;
 import org.onlab.packet.dhcp.CircuitId;
 import org.onlab.packet.dhcp.DhcpOption;
 import org.onlab.packet.dhcp.DhcpRelayAgentOption;
+import org.onosproject.core.ApplicationId;
+import org.onosproject.core.CoreService;
 import org.onosproject.dhcprelay.api.DhcpHandler;
 import org.onosproject.dhcprelay.api.DhcpServerInfo;
 import org.onosproject.dhcprelay.config.DhcpServerConfig;
+import org.onosproject.dhcprelay.config.IgnoreDhcpConfig;
 import org.onosproject.dhcprelay.store.DhcpRecord;
 import org.onosproject.dhcprelay.store.DhcpRelayStore;
+import org.onosproject.net.Device;
+import org.onosproject.net.DeviceId;
+import org.onosproject.net.behaviour.Pipeliner;
+import org.onosproject.net.device.DeviceService;
+import org.onosproject.net.flow.DefaultTrafficSelector;
+import org.onosproject.net.flow.TrafficSelector;
+import org.onosproject.net.flowobjective.DefaultForwardingObjective;
+import org.onosproject.net.flowobjective.FlowObjectiveService;
+import org.onosproject.net.flowobjective.ForwardingObjective;
+import org.onosproject.net.flowobjective.Objective;
+import org.onosproject.net.flowobjective.ObjectiveContext;
+import org.onosproject.net.flowobjective.ObjectiveError;
 import org.onosproject.net.host.HostEvent;
 import org.onosproject.net.host.HostListener;
 import org.onosproject.net.host.HostProvider;
@@ -50,6 +69,7 @@
 import org.onosproject.net.host.HostProviderService;
 import org.onosproject.net.intf.Interface;
 import org.onosproject.net.intf.InterfaceService;
+import org.onosproject.net.packet.PacketPriority;
 import org.onosproject.net.provider.ProviderId;
 import org.onosproject.routeservice.Route;
 import org.onosproject.routeservice.RouteStore;
@@ -76,6 +96,7 @@
 import java.util.List;
 import java.util.Optional;
 import java.util.Set;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkNotNull;
@@ -85,6 +106,8 @@
 import static org.onlab.packet.DHCP.DHCPOptionCode.OptionCode_MessageType;
 import static org.onlab.packet.MacAddress.valueOf;
 import static org.onlab.packet.dhcp.DhcpRelayAgentOption.RelayAgentInfoOptions.CIRCUIT_ID;
+import static org.onosproject.net.flowobjective.Objective.Operation.ADD;
+import static org.onosproject.net.flowobjective.Objective.Operation.REMOVE;
 
 @Component
 @Service
@@ -92,6 +115,34 @@
 public class Dhcp4HandlerImpl implements DhcpHandler, HostProvider {
     public static final String DHCP_V4_RELAY_APP = "org.onosproject.Dhcp4HandlerImpl";
     public static final ProviderId PROVIDER_ID = new ProviderId("dhcp4", DHCP_V4_RELAY_APP);
+    private static final String BROADCAST_IP = "255.255.255.255";
+    // Priority of the VLAN ignore rules: above PacketPriority.CONTROL, the priority this class
+    // uses for its DHCP packet requests, so the ignore rules take precedence.
+    private static final int IGNORE_CONTROL_PRIORITY = PacketPriority.CONTROL.priorityValue() + 1000;
+
+    // DHCP packets broadcast by clients: source 0.0.0.0, destination 255.255.255.255,
+    // UDP client port -> server port.
+    private static final TrafficSelector CLIENT_SERVER_SELECTOR = DefaultTrafficSelector.builder()
+            .matchEthType(Ethernet.TYPE_IPV4)
+            .matchIPProtocol(IPv4.PROTOCOL_UDP)
+            .matchIPSrc(Ip4Address.ZERO.toIpPrefix())
+            .matchIPDst(Ip4Address.valueOf(BROADCAST_IP).toIpPrefix())
+            .matchUdpSrc(TpPort.tpPort(UDP.DHCP_CLIENT_PORT))
+            .matchUdpDst(TpPort.tpPort(UDP.DHCP_SERVER_PORT))
+            .build();
+    // DHCP packets with the UDP server port on both ends (relay <-> server); the packet
+    // request helpers narrow it by server IP as source or destination.
+    private static final TrafficSelector SERVER_RELAY_SELECTOR = DefaultTrafficSelector.builder()
+            .matchEthType(Ethernet.TYPE_IPV4)
+            .matchIPProtocol(IPv4.PROTOCOL_UDP)
+            .matchUdpSrc(TpPort.tpPort(UDP.DHCP_SERVER_PORT))
+            .matchUdpDst(TpPort.tpPort(UDP.DHCP_SERVER_PORT))
+            .build();
+    // Base selectors from which the per-VLAN ignore rules are built.
+    static final Set<TrafficSelector> DHCP_SELECTORS = ImmutableSet.of(
+            CLIENT_SERVER_SELECTOR,
+            SERVER_RELAY_SELECTOR
+    );
     private static Logger log = LoggerFactory.getLogger(Dhcp4HandlerImpl.class);
 
     @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
@@ -112,7 +156,20 @@
     @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
     protected HostProviderRegistry providerRegistry;
 
+    @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
+    protected CoreService coreService;
+
+    @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
+    protected DeviceService deviceService;
+
+    @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
+    protected FlowObjectiveService flowObjectiveService;
+
     protected HostProviderService providerService;
+    protected ApplicationId appId;
+    // Device -> VLANs whose DHCP ignore rules have been installed. It is mutated from the
+    // flow objective callbacks in processIgnoreVlanRule, hence the synchronized wrapper.
+    protected Multimap<DeviceId, VlanId> ignoredVlans = Multimaps.synchronizedMultimap(HashMultimap.create());
     private InternalHostListener hostListener = new InternalHostListener();
 
     private List<DhcpServerInfo> defaultServerInfoList = Lists.newArrayList();
@@ -120,6 +175,8 @@
 
     @Activate
     protected void activate() {
+        // Application id used as the owner of this handler's packet requests and flow objectives.
+        appId = coreService.registerApplication(DHCP_V4_RELAY_APP);
         hostService.addListener(hostListener);
         providerService = providerRegistry.register(this);
     }
@@ -163,6 +219,43 @@
         return indirectServerInfoList;
     }
 
+    /**
+     * Installs ignore rules for VLANs newly present in the config and removes
+     * rules for VLANs no longer present. A null config removes every rule.
+     * <p>
+     * The installed state is iterated from a snapshot because the objective
+     * callbacks in processIgnoreVlanRule add to and remove from ignoredVlans.
+     *
+     * @param config the ignore configuration, or null when it was removed
+     */
+    @Override
+    public void updateIgnoreVlanConfig(IgnoreDhcpConfig config) {
+        Multimap<DeviceId, VlanId> installedVlans;
+        synchronized (ignoredVlans) {
+            installedVlans = HashMultimap.create(ignoredVlans);
+        }
+
+        if (config == null) {
+            installedVlans.forEach((deviceId, vlanId) -> processIgnoreVlanRule(deviceId, vlanId, REMOVE));
+            return;
+        }
+
+        config.ignoredVlans().forEach((deviceId, vlanId) -> {
+            if (installedVlans.containsEntry(deviceId, vlanId)) {
+                // rules for this VLAN are already installed
+                return;
+            }
+            processIgnoreVlanRule(deviceId, vlanId, ADD);
+        });
+
+        installedVlans.forEach((deviceId, vlanId) -> {
+            if (!config.ignoredVlans().get(deviceId).contains(vlanId)) {
+                // no longer in the new config, remove it
+                processIgnoreVlanRule(deviceId, vlanId, REMOVE);
+            }
+        });
+    }
+
     public void setDhcpServerConfigs(Collection<DhcpServerConfig> configs, List<DhcpServerInfo> serverInfoList) {
         if (configs.size() == 0) {
             // no config to update
@@ -188,6 +268,8 @@
             });
             oldServerInfo.getDhcpServerIp4().ifPresent(serverIp -> {
                 hostService.stopMonitoringIp(serverIp);
+                // Stop intercepting DHCP packets exchanged with the old server.
+                cancelDhcpPacket(serverIp);
             });
         }
 
@@ -202,7 +283,10 @@
         log.debug("DHCP server connect point: {}", newServerInfo.getDhcpServerConnectPoint().orElse(null));
         log.debug("DHCP server IP: {}", newServerInfo.getDhcpServerIp4().orElse(null));
 
-        IpAddress ipToProbe;
+        // NOTE(review): unchecked Optional.get(); this throws if the new server has no IPv4
+        // address (the debug log above treats it as optional) -- confirm it is validated earlier.
+        Ip4Address serverIp = newServerInfo.getDhcpServerIp4().get();
+        Ip4Address ipToProbe;
         if (newServerInfo.getDhcpGatewayIp4().isPresent()) {
             ipToProbe = newServerInfo.getDhcpGatewayIp4().get();
         } else {
@@ -228,6 +310,8 @@
         nonDupServerInfoList.addAll(serverInfoList);
         serverInfoList.clear();
         serverInfoList.addAll(nonDupServerInfoList);
+        // Start intercepting DHCP packets exchanged with the new server.
+        requestDhcpPacket(serverIp);
     }
 
     @Override
@@ -414,7 +497,7 @@
             indirectRelayAgentIp = indirectServerInfo.getRelayAgentIp4().orElse(null);
         }
 
-
+        // Derived from the interfaces configured on the port this packet was received from.
         Ip4Address clientInterfaceIp =
                 interfaceService.getInterfacesByPort(context.inPacket().receivedFrom())
                         .stream()
@@ -982,9 +1064,8 @@
                 case HOST_REMOVED:
                     hostRemoved(event.subject());
                     break;
-                case HOST_MOVED:
-                    hostMoved(event.subject());
-                    break;
+                // NOTE(review): HOST_MOVED is no longer handled, so the connect MAC/VLAN is not cleared
+                // when the DHCP server or gateway moves off its configured connect point; confirm intended.
                 default:
                     break;
             }
@@ -992,114 +1071,45 @@
     }
 
     /**
-     * Handle host move.
-     * If the host DHCP server or gateway and it moved to the location different
-     * to user configured, unsets the connect mac and vlan
-     *
-     * @param host the host
-     */
-    private void hostMoved(Host host) {
-        Set<ConnectPoint> hostConnectPoints = host.locations().stream()
-                .map(hl -> new ConnectPoint(hl.elementId(), hl.port()))
-                .collect(Collectors.toSet());
-        DhcpServerInfo serverInfo;
-        ConnectPoint dhcpServerConnectPoint;
-        Ip4Address dhcpGatewayIp;
-        Ip4Address dhcpServerIp;
-
-        if (!defaultServerInfoList.isEmpty()) {
-            serverInfo = defaultServerInfoList.get(0);
-            dhcpServerConnectPoint = serverInfo.getDhcpServerConnectPoint().orElse(null);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp) &&
-                        !hostConnectPoints.contains(dhcpServerConnectPoint)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
-                }
-            }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp) &&
-                        !hostConnectPoints.contains(dhcpServerConnectPoint)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
-                }
-            }
-        }
-
-        if (!indirectServerInfoList.isEmpty()) {
-            // Indirect server
-            serverInfo = indirectServerInfoList.get(0);
-            dhcpServerConnectPoint = serverInfo.getDhcpServerConnectPoint().orElse(null);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp) &&
-                        !hostConnectPoints.contains(dhcpServerConnectPoint)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
-                }
-            }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp) &&
-                        !hostConnectPoints.contains(dhcpServerConnectPoint)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
-                }
-            }
-        }
-    }
-
-    /**
      * Handle host updated.
      * If the host is DHCP server or gateway, update connect mac and vlan.
      *
      * @param host the host
      */
     private void hostUpdated(Host host) {
-        DhcpServerInfo serverInfo;
-        Ip4Address dhcpGatewayIp;
-        Ip4Address dhcpServerIp;
-
-        if (!defaultServerInfoList.isEmpty()) {
-            serverInfo = defaultServerInfoList.get(0);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp)) {
-                    serverInfo.setDhcpConnectMac(host.mac());
-                    serverInfo.setDhcpConnectVlan(host.vlan());
-                }
-            }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp)) {
-                    serverInfo.setDhcpConnectMac(host.mac());
-                    serverInfo.setDhcpConnectVlan(host.vlan());
-                }
-            }
-        }
-
-        if (!indirectServerInfoList.isEmpty()) {
-            serverInfo = indirectServerInfoList.get(0);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp)) {
-                    serverInfo.setDhcpConnectMac(host.mac());
-                    serverInfo.setDhcpConnectVlan(host.vlan());
-                }
-            }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp)) {
-                    serverInfo.setDhcpConnectMac(host.mac());
-                    serverInfo.setDhcpConnectVlan(host.vlan());
-                }
-            }
-        }
-
+        hostUpdated(host, defaultServerInfoList);
+        hostUpdated(host, indirectServerInfoList);
     }
 
+    /**
+     * Handles a host update for one group of DHCP servers.
+     * Only the first entry of the list is considered. If the host owns the
+     * gateway IP (or the server IP when no gateway is configured), its MAC and
+     * VLAN become the connect MAC/VLAN and DHCP packets for the server are
+     * requested.
+     *
+     * @param host the updated host
+     * @param serverInfoList the DHCP server info list to check
+     */
+    private void hostUpdated(Host host, List<DhcpServerInfo> serverInfoList) {
+        if (serverInfoList.isEmpty()) {
+            return;
+        }
+        DhcpServerInfo serverInfo = serverInfoList.get(0);
+        Ip4Address serverIp = serverInfo.getDhcpServerIp4().orElse(null);
+        // when a gateway is configured it is the host to match, as when probing in setDhcpServerConfigs
+        Ip4Address targetIp = serverInfo.getDhcpGatewayIp4().orElse(serverIp);
+
+        if (targetIp != null && host.ipAddresses().contains(targetIp)) {
+            serverInfo.setDhcpConnectMac(host.mac());
+            serverInfo.setDhcpConnectVlan(host.vlan());
+            // packet requests are keyed on the server IP, which may be absent
+            if (serverIp != null) {
+                requestDhcpPacket(serverIp);
+            }
+        }
+    }
+
     /**
      * Handle host removed.
      * If the host is DHCP server or gateway, unset connect mac and vlan.
@@ -1107,45 +1111,204 @@
      * @param host the host
      */
     private void hostRemoved(Host host) {
+        hostRemoved(host, defaultServerInfoList);
+        hostRemoved(host, indirectServerInfoList);
+    }
+
+    /**
+     * Handles a host removal for one group of DHCP servers.
+     * Only the first entry of the list is considered. If the host owns the
+     * gateway IP (or the server IP when no gateway is configured), the connect
+     * MAC/VLAN are unset and DHCP packets for the server are no longer requested.
+     *
+     * @param host the removed host
+     * @param serverInfoList the DHCP server info list to check
+     */
+    private void hostRemoved(Host host, List<DhcpServerInfo> serverInfoList) {
         DhcpServerInfo serverInfo;
-        Ip4Address dhcpGatewayIp;
-        Ip4Address dhcpServerIp;
-        if (!defaultServerInfoList.isEmpty()) {
-            serverInfo = defaultServerInfoList.get(0);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
+        Ip4Address targetIp;
+        if (!serverInfoList.isEmpty()) {
+            serverInfo = serverInfoList.get(0);
+            Ip4Address serverIp = serverInfo.getDhcpServerIp4().orElse(null);
+            targetIp = serverInfo.getDhcpGatewayIp4().orElse(null);
 
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
-                }
+            if (targetIp == null) {
+                targetIp = serverIp;
             }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp)) {
+
+            if (targetIp != null) {
+                if (host.ipAddresses().contains(targetIp)) {
                     serverInfo.setDhcpConnectVlan(null);
                     serverInfo.setDhcpConnectMac(null);
+                    // packet requests are keyed on the server IP, which may be absent
+                    if (serverIp != null) {
+                        cancelDhcpPacket(serverIp);
+                    }
                 }
             }
         }
+    }
 
-        if (!indirectServerInfoList.isEmpty()) {
-            serverInfo = indirectServerInfoList.get(0);
-            dhcpGatewayIp = serverInfo.getDhcpGatewayIp4().orElse(null);
-            dhcpServerIp = serverInfo.getDhcpServerIp4().orElse(null);
+    /**
+     * Starts intercepting the DHCP packets exchanged with the given server.
+     *
+     * @param serverIp the DHCP server IP
+     */
+    private void requestDhcpPacket(Ip4Address serverIp) {
+        requestServerDhcpPacket(serverIp);
+        requestClientDhcpPacket(serverIp);
+    }
 
-            if (dhcpGatewayIp != null) {
-                if (host.ipAddresses().contains(dhcpGatewayIp)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
+    /**
+     * Stops intercepting the DHCP packets exchanged with the given server.
+     *
+     * @param serverIp the DHCP server IP
+     */
+    private void cancelDhcpPacket(Ip4Address serverIp) {
+        cancelServerDhcpPacket(serverIp);
+        cancelClientDhcpPacket(serverIp);
+    }
+
+    /**
+     * Builds the selector matching packets sent by the given server to a relay.
+     *
+     * @param serverIp the DHCP server IP
+     * @return the selector
+     */
+    private static TrafficSelector fromServerSelector(Ip4Address serverIp) {
+        return DefaultTrafficSelector.builder(SERVER_RELAY_SELECTOR)
+                .matchIPSrc(serverIp.toIpPrefix())
+                .build();
+    }
+
+    /**
+     * Builds the selector matching packets relayed to the given server.
+     *
+     * @param serverIp the DHCP server IP
+     * @return the selector
+     */
+    private static TrafficSelector toServerSelector(Ip4Address serverIp) {
+        return DefaultTrafficSelector.builder(SERVER_RELAY_SELECTOR)
+                .matchIPDst(serverIp.toIpPrefix())
+                .build();
+    }
+
+    private void cancelServerDhcpPacket(Ip4Address serverIp) {
+        packetService.cancelPackets(fromServerSelector(serverIp),
+                                    PacketPriority.CONTROL,
+                                    appId);
+    }
+
+    private void requestServerDhcpPacket(Ip4Address serverIp) {
+        packetService.requestPackets(fromServerSelector(serverIp),
+                                     PacketPriority.CONTROL,
+                                     appId);
+    }
+
+    private void cancelClientDhcpPacket(Ip4Address serverIp) {
+        // Packet comes from relay
+        packetService.cancelPackets(toServerSelector(serverIp),
+                                    PacketPriority.CONTROL,
+                                    appId);
+
+        // Packet comes from client
+        // NOTE(review): CLIENT_SERVER_SELECTOR does not depend on serverIp, so this also stops
+        // intercepting client broadcasts needed by any other configured server; confirm intended.
+        packetService.cancelPackets(CLIENT_SERVER_SELECTOR,
+                                    PacketPriority.CONTROL,
+                                    appId);
+    }
+
+    private void requestClientDhcpPacket(Ip4Address serverIp) {
+        // Packet comes from relay
+        packetService.requestPackets(toServerSelector(serverIp),
+                                     PacketPriority.CONTROL,
+                                     appId);
+
+        // Packet comes from client
+        packetService.requestPackets(CLIENT_SERVER_SELECTOR,
+                                     PacketPriority.CONTROL,
+                                     appId);
+    }
+
+    /**
+     * Process the ignore rules.
+     * Installs (ADD) or removes (REMOVE) one forwarding objective per entry of
+     * DHCP_SELECTORS, restricted to the given VLAN. ignoredVlans is updated only
+     * after every objective has succeeded; failures are only logged.
+     *
+     * @param deviceId the device id
+     * @param vlanId the vlan to be ignored
+     * @param op the operation, ADD to install; REMOVE to uninstall rules
+     */
+    private void processIgnoreVlanRule(DeviceId deviceId, VlanId vlanId, Objective.Operation op) {
+        // The device check does not depend on the selector, so do it once up front.
+        Device device = deviceService.getDevice(deviceId);
+        if (device == null || !device.is(Pipeliner.class)) {
+            // NOTE(review): nothing here retries once the device becomes available; ignoredVlans
+            // is left unchanged, so only a later config update will try again.
+            log.warn("Device {} is not available now, wait until device is available", deviceId);
+            return;
+        }
+
+        TrafficTreatment dropTreatment = DefaultTrafficTreatment.builder().wipeDeferred().build();
+        AtomicInteger installedCount = new AtomicInteger(DHCP_SELECTORS.size());
+        DHCP_SELECTORS.forEach(trafficSelector -> {
+            TrafficSelector selector = DefaultTrafficSelector.builder(trafficSelector)
+                    .matchVlanId(vlanId)
+                    .build();
+
+            ForwardingObjective.Builder builder = DefaultForwardingObjective.builder()
+                    .withFlag(ForwardingObjective.Flag.VERSATILE)
+                    .withSelector(selector)
+                    .withPriority(IGNORE_CONTROL_PRIORITY)
+                    .withTreatment(dropTreatment)
+                    .fromApp(appId);
+
+            ObjectiveContext objectiveContext = new ObjectiveContext() {
+                @Override
+                public void onSuccess(Objective objective) {
+                    log.info("Ignore rule {} (Vlan id {}, device {}, selector {})",
+                             op, vlanId, deviceId, selector);
+                    int countDown = installedCount.decrementAndGet();
+                    if (countDown != 0) {
+                        // wait until the objectives for all selectors have succeeded
+                        return;
+                    }
+                    switch (op) {
+                        case ADD:
+                            ignoredVlans.put(deviceId, vlanId);
+                            break;
+                        case REMOVE:
+                            ignoredVlans.remove(deviceId, vlanId);
+                            break;
+                        default:
+                            log.warn("Unsupported objective operation {}", op);
+                            break;
+                    }
                 }
-            }
-            if (dhcpServerIp != null) {
-                if (host.ipAddresses().contains(dhcpServerIp)) {
-                    serverInfo.setDhcpConnectVlan(null);
-                    serverInfo.setDhcpConnectMac(null);
+
+                @Override
+                public void onError(Objective objective, ObjectiveError error) {
+                    log.warn("Can't {} ignore rule (vlan id {}, selector {}, device {}) due to {}",
+                             op, vlanId, selector, deviceId, error);
                 }
+            };
+
+            ForwardingObjective fwd;
+            switch (op) {
+                case ADD:
+                    fwd = builder.add(objectiveContext);
+                    break;
+                case REMOVE:
+                    fwd = builder.remove(objectiveContext);
+                    break;
+                default:
+                    log.warn("Unsupported objective operation {}", op);
+                    return;
             }
-        }
+
+            flowObjectiveService.apply(deviceId, fwd);
+        });
     }
 }