netmask support in IpAddress

Change-Id: Ie5f4276a1fa1cdd56bebbd3cd1ee74ecacdab598
diff --git a/utils/misc/src/main/java/org/onlab/packet/IpAddress.java b/utils/misc/src/main/java/org/onlab/packet/IpAddress.java
index c664a3a..109672c 100644
--- a/utils/misc/src/main/java/org/onlab/packet/IpAddress.java
+++ b/utils/misc/src/main/java/org/onlab/packet/IpAddress.java
@@ -5,7 +5,7 @@
 /**
  * A class representing an IPv4 address.
  */
-public class IpAddress {
+public final class IpAddress {
 
     //IP Versions
     public enum Version { INET, INET6 };
@@ -14,13 +14,30 @@
     public static final int INET_LEN = 4;
     public static final int INET6_LEN = 16;
 
-    protected Version version;
-    //does it make more sense to have a integral address?
-    protected byte[] octets;
+    //maximum CIDR value
+    public static final int MAX_INET_MASK = 32;
+    public static final int DEFAULT_MASK = 0;
 
-    protected IpAddress(Version ver, byte[] octets) {
+    /**
+     * Default value indicating an unspecified address.
+     */
+    public static final byte [] ANY = new byte [] {0, 0, 0, 0};
+
+    protected Version version;
+
+    protected byte[] octets;
+    protected int netmask;
+
+    private IpAddress(Version ver, byte[] octets, int netmask) {
         this.version = ver;
         this.octets = Arrays.copyOf(octets, INET_LEN);
+        this.netmask = netmask;
+    }
+
+    private IpAddress(Version ver, byte[] octets) {
+        this.version = ver;
+        this.octets = Arrays.copyOf(octets, INET_LEN);
+        this.netmask = DEFAULT_MASK;
     }
 
     /**
@@ -34,38 +51,87 @@
     }
 
     /**
+     * Converts a byte array into an IP address.
+     *
+     * @param address a byte array
+     * @param netmask the CIDR value subnet mask
+     * @return an IP address
+     */
+    public static IpAddress valueOf(byte [] address, int netmask) {
+        return new IpAddress(Version.INET, address, netmask);
+    }
+
+    /**
+     * Helper to convert an integer into a byte array.
+     *
+     * @param address the integer to convert
+     * @return a byte array
+     */
+    private static byte [] bytes(int address) {
+        byte [] bytes = new byte [INET_LEN];
+        for (int i = 0; i < INET_LEN; i++) {
+            bytes[i] = (byte) ((address >> (INET_LEN - (i + 1)) * 8) & 0xff);
+        }
+
+        return bytes;
+    }
+
+    /**
      * Converts an integer into an IPv4 address.
      *
      * @param address an integer representing an IP value
      * @return an IP address
      */
     public static IpAddress valueOf(int address) {
-        byte [] bytes = new byte [INET_LEN];
-        for (int i = 0; i < INET_LEN; i++) {
-            bytes[i] = (byte) ((address >> (INET_LEN - (i + 1)) * 8) & 0xff);
-        }
-
-        return new IpAddress(Version.INET, bytes);
+        return new IpAddress(Version.INET, bytes(address));
     }
 
     /**
-     * Converts a string in dotted-decimal notation (x.x.x.x) into
-     * an IPv4 address.
+     * Converts an integer into an IPv4 address.
      *
-     * @param address a string representing an IP address, e.g. "10.0.0.1"
+     * @param address an integer representing an IP value
+     * @param netmask the CIDR value subnet mask
+     * @return an IP address
+     */
+    public static IpAddress valueOf(int address, int netmask) {
+        return new IpAddress(Version.INET, bytes(address), netmask);
+    }
+
+    /**
+     * Converts a dotted-decimal string (x.x.x.x) into an IPv4 address. The
+     * string can also be in CIDR (slash) notation.
+     *
+     * @param address a IP address in string form, e.g. "10.0.0.1", "10.0.0.1/24"
      * @return an IP address
      */
     public static IpAddress valueOf(String address) {
-        final String [] parts = address.split("\\.");
-        if (parts.length != INET_LEN) {
+
+        final String [] parts = address.split("\\/");
+        if (parts.length > 2) {
+            throw new IllegalArgumentException("Malformed IP address string; "
+                    + "Addres must take form \"x.x.x.x\" or \"x.x.x.x/y\"");
+        }
+
+        int mask = DEFAULT_MASK;
+        if (parts.length == 2) {
+            mask = Integer.valueOf(parts[1]);
+            if (mask > MAX_INET_MASK) {
+                throw new IllegalArgumentException(
+                        "Value of subnet mask cannot exceed "
+                        + MAX_INET_MASK);
+            }
+        }
+
+        final String [] net = parts[0].split("\\.");
+        if (net.length != INET_LEN) {
             throw new IllegalArgumentException("Malformed IP address string; "
                     + "Addres must have four decimal values separated by dots (.)");
         }
         final byte [] bytes = new byte[INET_LEN];
         for (int i = 0; i < INET_LEN; i++) {
-            bytes[i] = Byte.parseByte(parts[i], 10);
+            bytes[i] = (byte) Short.parseShort(net[i], 10);
         }
-        return new IpAddress(Version.INET, bytes);
+        return new IpAddress(Version.INET, bytes, mask);
     }
 
     /**
@@ -99,34 +165,122 @@
         return address;
     }
 
+    /**
+     * Helper for computing the mask value from CIDR.
+     *
+     * @return an integer bitmask
+     */
+    private int mask() {
+        int shift = MAX_INET_MASK - this.netmask;
+        return ((Integer.MAX_VALUE >>> (shift - 1)) << shift);
+    }
+
+    /**
+     * Returns the subnet mask in IpAddress form. The netmask value for
+     * the returned IpAddress is 0, as the address itself is a mask.
+     *
+     * @return the subnet mask
+     */
+    public IpAddress netmask() {
+        return new IpAddress(Version.INET, bytes(mask()));
+    }
+
+    /**
+     * Returns the network portion of this address as an IpAddress.
+     * The netmask of the returned IpAddress is the current mask. If this
+     * address doesn't have a mask, this returns an all-0 IpAddress.
+     *
+     * @return the network address or null
+     */
+    public IpAddress network() {
+        if (netmask == DEFAULT_MASK) {
+            return new IpAddress(version, ANY, DEFAULT_MASK);
+        }
+
+        byte [] net = new byte [4];
+        byte [] mask = bytes(mask());
+        for (int i = 0; i < INET_LEN; i++) {
+             net[i] = (byte) (octets[i] & mask[i]);
+        }
+        return new IpAddress(version, net, netmask);
+    }
+
+    /**
+     * Returns the host portion of the IPAddress, as an IPAddress.
+     * The netmask of the returned IpAddress is the current mask. If this
+     * address doesn't have a mask, this returns a copy of the current
+     * address.
+     *
+     * @return the host address
+     */
+    public IpAddress host() {
+        if (netmask == DEFAULT_MASK) {
+            new IpAddress(version, octets, netmask);
+        }
+
+        byte [] host = new byte [INET_LEN];
+        byte [] mask = bytes(mask());
+        for (int i = 0; i < INET_LEN; i++) {
+             host[i] = (byte) (octets[i] & ~mask[i]);
+        }
+        return new IpAddress(version, host, netmask);
+    }
+
     @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + netmask;
+        result = prime * result + Arrays.hashCode(octets);
+        result = prime * result + ((version == null) ? 0 : version.hashCode());
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null) {
+            return false;
+        }
+        if (getClass() != obj.getClass()) {
+            return false;
+        }
+        IpAddress other = (IpAddress) obj;
+        if (netmask != other.netmask) {
+            return false;
+        }
+        if (!Arrays.equals(octets, other.octets)) {
+            return false;
+        }
+        if (version != other.version) {
+            return false;
+        }
+        return true;
+    }
+
+    @Override
+    /*
+     * (non-Javadoc)
+     * format is "x.x.x.x" for non-masked (netmask 0) addresses,
+     * and "x.x.x.x/y" for masked addresses.
+     *
+     * @see java.lang.Object#toString()
+     */
     public String toString() {
         final StringBuilder builder = new StringBuilder();
         for (final byte b : this.octets) {
             if (builder.length() > 0) {
                 builder.append(".");
             }
-            builder.append(String.format("%d", b));
+            builder.append(String.format("%d", b & 0xff));
+        }
+        if (netmask != DEFAULT_MASK) {
+            builder.append("/");
+            builder.append(String.format("%d", netmask));
         }
         return builder.toString();
     }
 
-    @Override
-    public int hashCode() {
-        return Arrays.hashCode(octets);
-    }
-
-    @Override
-    public boolean equals(Object obj) {
-
-        if (obj instanceof IpAddress) {
-            IpAddress other = (IpAddress) obj;
-
-            if (this.version.equals(other.version)
-                    && (Arrays.equals(this.octets, other.octets))) {
-                return true;
-            }
-        }
-        return false;
-    }
 }
diff --git a/utils/misc/src/test/java/org/onlab/packet/IPAddressTest.java b/utils/misc/src/test/java/org/onlab/packet/IPAddressTest.java
index a0757cd..f1a7b0d 100644
--- a/utils/misc/src/test/java/org/onlab/packet/IPAddressTest.java
+++ b/utils/misc/src/test/java/org/onlab/packet/IPAddressTest.java
@@ -1,6 +1,7 @@
 package org.onlab.packet;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 import java.util.Arrays;
 
@@ -11,33 +12,65 @@
 
 public class IPAddressTest {
 
-    private static final byte [] BYTES1 = new byte [] {0x0, 0x0, 0x0, 0xa};
-    private static final byte [] BYTES2 = new byte [] {0x0, 0x0, 0x0, 0xb};
-    private static final int INTVAL1 = 10;
-    private static final int INTVAL2 = 12;
-    private static final String STRVAL = "0.0.0.11";
+    private static final byte [] BYTES1 = new byte [] {0xa, 0x0, 0x0, 0xa};
+    private static final byte [] BYTES2 = new byte [] {0xa, 0x0, 0x0, 0xb};
+    private static final int INTVAL1 = 167772170;
+    private static final int INTVAL2 = 167772171;
+    private static final String STRVAL = "10.0.0.12";
+    private static final int MASK = 16;
 
     @Test
     public void testEquality() {
         IpAddress ip1 = IpAddress.valueOf(BYTES1);
-        IpAddress ip2 = IpAddress.valueOf(BYTES2);
-        IpAddress ip3 = IpAddress.valueOf(INTVAL1);
+        IpAddress ip2 = IpAddress.valueOf(INTVAL1);
+        IpAddress ip3 = IpAddress.valueOf(BYTES2);
         IpAddress ip4 = IpAddress.valueOf(INTVAL2);
         IpAddress ip5 = IpAddress.valueOf(STRVAL);
 
-        new EqualsTester().addEqualityGroup(ip1, ip3)
-        .addEqualityGroup(ip2, ip5)
-        .addEqualityGroup(ip4)
+        new EqualsTester().addEqualityGroup(ip1, ip2)
+        .addEqualityGroup(ip3, ip4)
+        .addEqualityGroup(ip5)
         .testEquals();
+
+        // string conversions
+        IpAddress ip6 = IpAddress.valueOf(BYTES1, MASK);
+        IpAddress ip7 = IpAddress.valueOf("10.0.0.10/16");
+        IpAddress ip8 = IpAddress.valueOf(new byte [] {0xa, 0x0, 0x0, 0xc});
+        assertEquals("incorrect address conversion", ip6, ip7);
+        assertEquals("incorrect address conversion", ip5, ip8);
     }
 
     @Test
     public void basics() {
-        IpAddress ip4 = IpAddress.valueOf(BYTES1);
-        assertEquals("incorrect IP Version", Version.INET, ip4.version());
-        assertEquals("faulty toOctets()", Arrays.equals(
-                new byte [] {0x0, 0x0, 0x0, 0xa}, ip4.toOctets()), true);
-        assertEquals("faulty toInt()", INTVAL1, ip4.toInt());
-        assertEquals("faulty toString()", "0.0.0.10", ip4.toString());
+        IpAddress ip1 = IpAddress.valueOf(BYTES1, MASK);
+        final byte [] bytes = new byte [] {0xa, 0x0, 0x0, 0xa};
+
+        //check fields
+        assertEquals("incorrect IP Version", Version.INET, ip1.version());
+        assertEquals("incorrect netmask", 16, ip1.netmask);
+        assertTrue("faulty toOctets()", Arrays.equals(bytes, ip1.toOctets()));
+        assertEquals("faulty toInt()", INTVAL1, ip1.toInt());
+        assertEquals("faulty toString()", "10.0.0.10/16", ip1.toString());
+    }
+
+    @Test
+    public void netmasks() {
+        // masked
+        IpAddress ip1 = IpAddress.valueOf(BYTES1, MASK);
+
+        IpAddress host = IpAddress.valueOf("0.0.0.10/16");
+        IpAddress network = IpAddress.valueOf("10.0.0.0/16");
+        assertEquals("incorrect host address", host, ip1.host());
+        assertEquals("incorrect network address", network, ip1.network());
+        assertEquals("incorrect netmask", "255.255.0.0", ip1.netmask().toString());
+
+        //unmasked
+        IpAddress ip2 = IpAddress.valueOf(BYTES1);
+        IpAddress umhost = IpAddress.valueOf("10.0.0.10/0");
+        IpAddress umnet = IpAddress.valueOf("0.0.0.0/0");
+        assertEquals("incorrect host address", umhost, ip2.host());
+        assertEquals("incorrect host address", umnet, ip2.network());
+        assertTrue("incorrect netmask",
+                Arrays.equals(IpAddress.ANY, ip2.netmask().toOctets()));
     }
 }