Secure LLDP-based Topology Detection

The current LLDP/BDDP-based Topology Detection is vulnerable to the
creation of fake links via forged, modified, or replayed LLDP packets.
This patch fixes this vulnerability by authenticating LLDP/BDDP packets
using a Message Authentication Code and adding a timestamp so that
replayed packets are rejected once they are older than a maximum delay.
We use HMAC with SHA-256 as our Message Authentication Code and
derive the key from the config/cluster.json file via the
ClusterMetadata class.

Change-Id: I01dd6edc5cffd6dfe274bcdb97189f2661a6c4f1
diff --git a/utils/misc/src/main/java/org/onlab/packet/ONOSLLDP.java b/utils/misc/src/main/java/org/onlab/packet/ONOSLLDP.java
index b1e44c5..f0b981c 100644
--- a/utils/misc/src/main/java/org/onlab/packet/ONOSLLDP.java
+++ b/utils/misc/src/main/java/org/onlab/packet/ONOSLLDP.java
@@ -21,9 +21,14 @@
 
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
+import java.security.InvalidKeyException;
+import java.security.NoSuchAlgorithmException;
 import java.util.Arrays;
 import java.util.HashMap;
 
+import javax.crypto.Mac;
+import javax.crypto.spec.SecretKeySpec;
+
 import static org.onlab.packet.LLDPOrganizationalTLV.OUI_LENGTH;
 import static org.onlab.packet.LLDPOrganizationalTLV.SUBTYPE_LENGTH;
 
@@ -39,10 +44,14 @@
     protected static final byte NAME_SUBTYPE = 1;
     protected static final byte DEVICE_SUBTYPE = 2;
     protected static final byte DOMAIN_SUBTYPE = 3;
+    protected static final byte TIMESTAMP_SUBTYPE = 4; // probe send time, 8-byte big-endian long
+    protected static final byte SIG_SUBTYPE = 5; // HMAC-SHA256 over device ID, port and timestamp
 
     private static final short NAME_LENGTH = OUI_LENGTH + SUBTYPE_LENGTH;
     private static final short DEVICE_LENGTH = OUI_LENGTH + SUBTYPE_LENGTH;
     private static final short DOMAIN_LENGTH = OUI_LENGTH + SUBTYPE_LENGTH;
+    private static final short TIMESTAMP_LENGTH = OUI_LENGTH + SUBTYPE_LENGTH; // excludes the 8-byte payload
+    private static final short SIG_LENGTH = OUI_LENGTH + SUBTYPE_LENGTH; // excludes the MAC payload
 
     private final HashMap<Byte, LLDPOrganizationalTLV> opttlvs = Maps.newHashMap();
 
@@ -138,6 +147,40 @@
         this.setPortId(portTLV);
     }
 
+    /**
+     * Sets the payload of this probe's timestamp TLV.
+     * Does nothing if no timestamp TLV has been allocated for this probe.
+     *
+     * @param timestamp send time; verify() compares it with System.currentTimeMillis()
+     */
+    public void setTimestamp(long timestamp) {
+        LLDPOrganizationalTLV tmtlv = opttlvs.get(TIMESTAMP_SUBTYPE);
+        if (tmtlv == null) {
+            return;
+        }
+        tmtlv.setInfoString(ByteBuffer.allocate(Long.BYTES).putLong(timestamp).array());
+        tmtlv.setLength((short) (Long.BYTES + TIMESTAMP_LENGTH));
+        tmtlv.setSubType(TIMESTAMP_SUBTYPE);
+        tmtlv.setOUI(MacAddress.ONOS.oui());
+    }
+
+    /**
+     * Sets the payload of this probe's signature TLV.
+     * Does nothing if no signature TLV has been allocated for this probe.
+     *
+     * @param sig message authentication code to embed
+     */
+    public void setSig(byte[] sig) {
+        LLDPOrganizationalTLV sigtlv = opttlvs.get(SIG_SUBTYPE);
+        if (sigtlv == null) {
+            return;
+        }
+        sigtlv.setInfoString(sig);
+        sigtlv.setLength((short) (sig.length + SIG_LENGTH));
+        sigtlv.setSubType(SIG_SUBTYPE);
+        sigtlv.setOUI(MacAddress.ONOS.oui());
+    }
+
     public LLDPOrganizationalTLV getNameTLV() {
         for (LLDPTLV tlv : this.getOptionalTLVList()) {
             if (tlv.getType() == LLDPOrganizationalTLV.ORGANIZATIONAL_TLV_TYPE) {
@@ -153,7 +184,7 @@
     public LLDPOrganizationalTLV getDeviceTLV() {
         for (LLDPTLV tlv : this.getOptionalTLVList()) {
             if (tlv.getType() == LLDPOrganizationalTLV.ORGANIZATIONAL_TLV_TYPE) {
-                LLDPOrganizationalTLV orgTLV =  (LLDPOrganizationalTLV) tlv;
+                LLDPOrganizationalTLV orgTLV = (LLDPOrganizationalTLV) tlv;
                 if (orgTLV.getSubType() == DEVICE_SUBTYPE) {
                     return orgTLV;
                 }
@@ -162,6 +193,40 @@
         return null;
     }
 
+    /**
+     * Gets the organizational TLV carrying the probe timestamp.
+     *
+     * @return the first TLV with the timestamp subtype, or null if there is none
+     */
+    public LLDPOrganizationalTLV getTimestampTLV() {
+        for (LLDPTLV tlv : this.getOptionalTLVList()) {
+            if (tlv.getType() == LLDPOrganizationalTLV.ORGANIZATIONAL_TLV_TYPE) {
+                LLDPOrganizationalTLV orgTLV = (LLDPOrganizationalTLV) tlv;
+                if (orgTLV.getSubType() == TIMESTAMP_SUBTYPE) {
+                    return orgTLV;
+                }
+            }
+        }
+        return null;
+    }
+
+    /**
+     * Gets the organizational TLV carrying the probe signature.
+     *
+     * @return the first TLV with the signature subtype, or null if there is none
+     */
+    public LLDPOrganizationalTLV getSigTLV() {
+        for (LLDPTLV tlv : this.getOptionalTLVList()) {
+            if (tlv.getType() == LLDPOrganizationalTLV.ORGANIZATIONAL_TLV_TYPE) {
+                LLDPOrganizationalTLV orgTLV = (LLDPOrganizationalTLV) tlv;
+                if (orgTLV.getSubType() == SIG_SUBTYPE) {
+                    return orgTLV;
+                }
+            }
+        }
+        return null;
+    }
+
     /**
      * Gets the TLV associated with remote probing. This TLV will be null if
      * remote probing is disabled.
@@ -212,6 +267,39 @@
                 portBB.position(), portBB.remaining(), StandardCharsets.UTF_8));
     }
 
+    /**
+     * Gets the timestamp carried by this probe.
+     *
+     * @return the embedded timestamp, or 0 if the timestamp TLV is missing or
+     *         its payload is not exactly 8 bytes long
+     */
+    public long getTimestamp() {
+        LLDPOrganizationalTLV tlv = getTimestampTLV();
+        if (tlv == null) {
+            return 0;
+        }
+        byte[] info = tlv.getInfoString();
+        // The payload comes off the wire and may be forged: reject any size other
+        // than a long instead of letting ByteBuffer throw Buffer{Over,Under}flowException.
+        if (info == null || info.length != Long.BYTES) {
+            return 0;
+        }
+        return ByteBuffer.wrap(info).getLong();
+    }
+
+    /**
+     * Gets the signature carried by this probe.
+     *
+     * @return the raw signature bytes, or null if the probe has no signature TLV
+     */
+    public byte[] getSig() {
+        LLDPOrganizationalTLV tlv = getSigTLV();
+        if (tlv != null) {
+            return tlv.getInfoString();
+        }
+        return null;
+    }
+
     /**
      * Given an ethernet packet, determines if this is an LLDP from
      * ONOS and returns the device the LLDP came from.
@@ -231,12 +304,14 @@
 
     /**
      * Creates a link probe for link discovery/verification.
+     * @deprecated since 1.15. Insecure, use {@link #onosSecureLLDP(String, ChassisId, int, String)} instead.
      *
      * @param deviceId The device ID as a String
      * @param chassisId The chassis ID of the device
      * @param portNum Port number of port to send probe out of
      * @return ONOSLLDP probe message
      */
+    @Deprecated
     public static ONOSLLDP onosLLDP(String deviceId, ChassisId chassisId, int portNum) {
         ONOSLLDP probe = new ONOSLLDP(NAME_SUBTYPE, DEVICE_SUBTYPE);
         probe.setPortId(portNum);
@@ -251,13 +326,78 @@
      * @param deviceId The device ID as a String
      * @param chassisId The chassis ID of the device
      * @param portNum Port number of port to send probe out of
+     * @param secret LLDP secret used as the HMAC key; if null, the probe is
+     *               neither timestamped nor signed
+     * @return ONOSLLDP probe message, or null if the probe could not be signed
+     */
+    public static ONOSLLDP onosSecureLLDP(String deviceId, ChassisId chassisId, int portNum, String secret) {
+        ONOSLLDP probe;
+        if (secret == null) {
+            probe = new ONOSLLDP(NAME_SUBTYPE, DEVICE_SUBTYPE);
+        } else {
+            probe = new ONOSLLDP(NAME_SUBTYPE, DEVICE_SUBTYPE, TIMESTAMP_SUBTYPE, SIG_SUBTYPE);
+        }
+        probe.setPortId(portNum);
+        probe.setDevice(deviceId);
+        probe.setChassisId(chassisId);
+
+        if (secret != null) {
+            // Secure mode: stamp the send time and authenticate device ID, port and
+            // timestamp so that verify() can reject forged, modified, or stale probes.
+            long ts = System.currentTimeMillis();
+            probe.setTimestamp(ts);
+            byte[] sig = createSig(deviceId, portNum, ts, secret);
+            if (sig == null) {
+                return null;
+            }
+            probe.setSig(sig);
+        }
+        return probe;
+    }
+
+    /**
+     * Creates a link probe for link discovery/verification.
+     * @deprecated since 1.15. Insecure, use
+     * {@link #onosSecureLLDP(String, ChassisId, int, String, String)} instead.
+     *
+     * @param deviceId The device ID as a String
+     * @param chassisId The chassis ID of the device
+     * @param portNum Port number of port to send probe out of
      * @param portDesc Port description of port to send probe out of
      * @return ONOSLLDP probe message
      */
+    @Deprecated
     public static ONOSLLDP onosLLDP(String deviceId, ChassisId chassisId, int portNum, String portDesc) {
-
         ONOSLLDP probe = onosLLDP(deviceId, chassisId, portNum);
+        addPortDesc(probe, portDesc);
+        return probe;
+    }
 
+    /**
+     * Creates a link probe for link discovery/verification.
+     * The port description is not covered by the signature.
+     *
+     * @param deviceId  The device ID as a String
+     * @param chassisId The chassis ID of the device
+     * @param portNum   Port number of port to send probe out of
+     * @param portDesc  Port description of port to send probe out of
+     * @param secret    LLDP secret used as the HMAC key; if null, the probe is not signed
+     * @return ONOSLLDP probe message, or null if the probe could not be signed
+     */
+    public static ONOSLLDP onosSecureLLDP(String deviceId, ChassisId chassisId, int portNum, String portDesc,
+                                          String secret) {
+        ONOSLLDP probe = onosSecureLLDP(deviceId, chassisId, portNum, secret);
+        // Signing may fail and yield null; do not hand that to addPortDesc.
+        if (probe != null) {
+            addPortDesc(probe, portDesc);
+        }
+        return probe;
+    }
+
+    /**
+     * Adds a port description TLV to the probe if portDesc is non-empty.
+     */
+    private static void addPortDesc(ONOSLLDP probe, String portDesc) {
         if (portDesc != null && !portDesc.isEmpty()) {
             byte[] bPortDesc = portDesc.getBytes(StandardCharsets.UTF_8);
 
@@ -270,7 +401,100 @@
                     .setValue(bPortDesc);
             probe.addOptionalTLV(portDescTlv);
         }
-        return probe;
+    }
+
+    /**
+     * Computes the HMAC-SHA256 that authenticates a probe.
+     * The MAC covers the UTF-8 bytes of the device ID followed by the port number
+     * and the timestamp, each encoded as an 8-byte big-endian long.
+     *
+     * @param deviceId  device ID as a String
+     * @param portNum   port number the probe is sent out of
+     * @param timestamp probe timestamp
+     * @param secret    shared secret; its UTF-8 bytes are the HMAC key
+     * @return the MAC, or null if HMAC-SHA256 is unavailable or the key is rejected
+     *         (e.g. an empty secret)
+     */
+    private static byte[] createSig(String deviceId, int portNum, long timestamp, String secret) {
+        byte[] pnb = ByteBuffer.allocate(Long.BYTES).putLong(portNum).array();
+        byte[] tmb = ByteBuffer.allocate(Long.BYTES).putLong(timestamp).array();
+
+        try {
+            SecretKeySpec signingKey = new SecretKeySpec(secret.getBytes(StandardCharsets.UTF_8), "HmacSHA256");
+            Mac mac = Mac.getInstance("HmacSHA256");
+            mac.init(signingKey);
+            // Explicit charset: the platform default may differ between sender and
+            // receiver, which would make valid probes fail verification.
+            mac.update(deviceId.getBytes(StandardCharsets.UTF_8));
+            mac.update(pnb);
+            mac.update(tmb);
+            return mac.doFinal();
+        } catch (NoSuchAlgorithmException | InvalidKeyException | IllegalArgumentException e) {
+            // SecretKeySpec throws IllegalArgumentException for an empty key.
+            return null;
+        }
+    }
+
+    /**
+     * Checks a received MAC against the one recomputed from the given fields.
+     *
+     * @param sig       received MAC
+     * @param deviceId  device ID as a String
+     * @param portNum   port number
+     * @param timestamp probe timestamp
+     * @param secret    shared secret
+     * @return true if sig matches the recomputed MAC
+     */
+    private static boolean verifySig(byte[] sig, String deviceId, int portNum, long timestamp, String secret) {
+        byte[] nsig = createSig(deviceId, portNum, timestamp, secret);
+        if (nsig == null) {
+            return false;
+        }
+
+        if (!ArrayUtils.isSameLength(nsig, sig)) {
+            return false;
+        }
+
+        // Compare every byte rather than stopping at the first mismatch so the
+        // running time does not reveal how much of a forged MAC was correct.
+        int diff = 0;
+        for (int i = 0; i < nsig.length; i++) {
+            diff |= sig[i] ^ nsig[i];
+        }
+        return diff == 0;
+    }
+
+    /**
+     * Checks that a received probe is authentic and fresh.
+     *
+     * @param probe    received probe
+     * @param secret   shared secret; if null, verification is disabled and
+     *                 every probe is accepted
+     * @param maxDelay maximum accepted probe age, in milliseconds
+     * @return true if the probe is accepted
+     */
+    public static boolean verify(ONOSLLDP probe, String secret, long maxDelay) {
+        if (secret == null) {
+            return true;
+        }
+
+        String deviceId = probe.getDeviceString();
+        int portNum = probe.getPort();
+        long timestamp = probe.getTimestamp();
+        byte[] sig = probe.getSig();
+
+        if (deviceId == null || sig == null) {
+            return false;
+        }
+
+        // Sample the clock once so both bounds use the same instant. Probes older
+        // than maxDelay or stamped in the future are rejected to limit replay.
+        long now = System.currentTimeMillis();
+        if (timestamp + maxDelay <= now || timestamp > now) {
+            return false;
+        }
+
+        return verifySig(sig, deviceId, portNum, timestamp, secret);
     }
 
 }
diff --git a/utils/misc/src/test/java/org/onlab/packet/ONOSLLDPTest.java b/utils/misc/src/test/java/org/onlab/packet/ONOSLLDPTest.java
index 268b5f5..b2494b5 100644
--- a/utils/misc/src/test/java/org/onlab/packet/ONOSLLDPTest.java
+++ b/utils/misc/src/test/java/org/onlab/packet/ONOSLLDPTest.java
@@ -30,8 +30,9 @@
     private static final Integer PORT_NUMBER = 2;
     private static final Integer PORT_NUMBER_2 = 98761234;
     private static final String PORT_DESC = "Ethernet1";
+    private static final String TEST_SECRET = "test"; // HMAC key used to sign the probe under test
 
-    private ONOSLLDP onoslldp = ONOSLLDP.onosLLDP(DEVICE_ID, CHASSIS_ID, PORT_NUMBER, PORT_DESC);
+    private ONOSLLDP onoslldp = ONOSLLDP.onosSecureLLDP(DEVICE_ID, CHASSIS_ID, PORT_NUMBER, PORT_DESC, TEST_SECRET);
 
     /**
      * Tests port number and getters.