Ensure P4Runtime byte strings are padded to their bit width

The P4Runtime server may send canonical byte strings (i.e.,
non-padded byte strings).
In ONOS we ensure, in the codecs, that all byte strings are
padded to match the model (P4Info) bit width. In this way,
we provide read-write symmetry inside ONOS.
ONOS always pads byte strings when sending messages to the
P4Runtime server.
This patch doesn't enforce read-write symmetry between
P4Runtime client and server on the wire.

N.B.: the current padding implementation works ONLY when
using non-negative integer.

Change-Id: I9f8e43de015bd0929dd543d7688c8e71bf5fe98d
diff --git a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/ActionCodec.java b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/ActionCodec.java
index d0a285e..daf91f5 100644
--- a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/ActionCodec.java
+++ b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/ActionCodec.java
@@ -28,6 +28,8 @@
 import p4.v1.P4RuntimeOuterClass;
 
 import static java.lang.String.format;
+import static org.onlab.util.ImmutableByteSequence.copyAndFit;
+import static org.onlab.util.ImmutableByteSequence.copyFrom;
 import static org.onosproject.p4runtime.ctl.codec.Utils.assertSize;
 
 /**
@@ -65,7 +67,7 @@
     protected PiAction decode(
             P4RuntimeOuterClass.Action message, Object ignored,
             PiPipeconf pipeconf, P4InfoBrowser browser)
-            throws P4InfoBrowser.NotFoundException {
+            throws P4InfoBrowser.NotFoundException, CodecException {
         final P4InfoBrowser.EntityBrowser<P4InfoOuterClass.Action.Param> paramInfo =
                 browser.actionParams(message.getActionId());
         final String actionName = browser.actions()
@@ -77,11 +79,17 @@
             final P4InfoOuterClass.Action.Param actionParam = paramInfo.getById(p.getParamId());
             final ImmutableByteSequence value;
             if (browser.isTypeString(actionParam.getTypeName())) {
-                value = ImmutableByteSequence.copyFrom(new String(p.getValue().toByteArray()));
+                value = copyFrom(new String(p.getValue().toByteArray()));
             } else {
-                value = ImmutableByteSequence.copyFrom(p.getValue().toByteArray());
+                try {
+                    value = copyAndFit(p.getValue().asReadOnlyByteBuffer(),
+                                       actionParam.getBitwidth());
+                } catch (ImmutableByteSequence.ByteSequenceTrimException e) {
+                    throw new CodecException(e.getMessage());
+                }
             }
-            builder.withParameter(new PiActionParam(PiActionParamId.of(actionParam.getName()), value));
+            builder.withParameter(new PiActionParam(
+                    PiActionParamId.of(actionParam.getName()), value));
         }
         return builder.build();
     }
diff --git a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/FieldMatchCodec.java b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/FieldMatchCodec.java
index 6b01546..48baa17 100644
--- a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/FieldMatchCodec.java
+++ b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/FieldMatchCodec.java
@@ -31,6 +31,7 @@
 import p4.v1.P4RuntimeOuterClass;
 
 import static java.lang.String.format;
+import static org.onlab.util.ImmutableByteSequence.copyAndFit;
 import static org.onlab.util.ImmutableByteSequence.copyFrom;
 import static org.onosproject.p4runtime.ctl.codec.Utils.assertPrefixLen;
 import static org.onosproject.p4runtime.ctl.codec.Utils.assertSize;
@@ -155,48 +156,66 @@
         final P4InfoOuterClass.MatchField matchField =
                 browser.matchFields(tablePreamble.getId())
                         .getById(message.getFieldId());
+        final int fieldBitwidth = matchField.getBitwidth();
         final PiMatchFieldId headerFieldId = PiMatchFieldId.of(matchField.getName());
         final boolean isSdnString = browser.isTypeString(matchField.getTypeName());
 
-        P4RuntimeOuterClass.FieldMatch.FieldMatchTypeCase typeCase = message.getFieldMatchTypeCase();
-
-        switch (typeCase) {
-            case EXACT:
-                P4RuntimeOuterClass.FieldMatch.Exact exactFieldMatch = message.getExact();
-                ImmutableByteSequence exactValue;
-                if (isSdnString) {
-                    exactValue = copyFrom(new String(exactFieldMatch.getValue().toByteArray()));
-                } else {
-                    exactValue = copyFrom(exactFieldMatch.getValue().asReadOnlyByteBuffer());
-                }
-                return new PiExactFieldMatch(headerFieldId, exactValue);
-            case TERNARY:
-                P4RuntimeOuterClass.FieldMatch.Ternary ternaryFieldMatch = message.getTernary();
-                ImmutableByteSequence ternaryValue = copyFrom(ternaryFieldMatch.getValue().asReadOnlyByteBuffer());
-                ImmutableByteSequence ternaryMask = copyFrom(ternaryFieldMatch.getMask().asReadOnlyByteBuffer());
-                return new PiTernaryFieldMatch(headerFieldId, ternaryValue, ternaryMask);
-            case LPM:
-                P4RuntimeOuterClass.FieldMatch.LPM lpmFieldMatch = message.getLpm();
-                ImmutableByteSequence lpmValue = copyFrom(lpmFieldMatch.getValue().asReadOnlyByteBuffer());
-                int lpmPrefixLen = lpmFieldMatch.getPrefixLen();
-                return new PiLpmFieldMatch(headerFieldId, lpmValue, lpmPrefixLen);
-            case RANGE:
-                P4RuntimeOuterClass.FieldMatch.Range rangeFieldMatch = message.getRange();
-                ImmutableByteSequence rangeHighValue = copyFrom(rangeFieldMatch.getHigh().asReadOnlyByteBuffer());
-                ImmutableByteSequence rangeLowValue = copyFrom(rangeFieldMatch.getLow().asReadOnlyByteBuffer());
-                return new PiRangeFieldMatch(headerFieldId, rangeLowValue, rangeHighValue);
-            case OPTIONAL:
-                P4RuntimeOuterClass.FieldMatch.Optional optionalFieldMatch = message.getOptional();
-                ImmutableByteSequence optionalValue;
-                if (isSdnString) {
-                    optionalValue = copyFrom(new String(optionalFieldMatch.getValue().toByteArray()));
-                } else {
-                    optionalValue = copyFrom(optionalFieldMatch.getValue().asReadOnlyByteBuffer());
-                }
-                return new PiOptionalFieldMatch(headerFieldId, optionalValue);
-            default:
-                throw new CodecException(format(
-                        "Decoding of field match type '%s' not implemented", typeCase.name()));
+        final P4RuntimeOuterClass.FieldMatch.FieldMatchTypeCase typeCase = message.getFieldMatchTypeCase();
+        try {
+            switch (typeCase) {
+                case EXACT:
+                    P4RuntimeOuterClass.FieldMatch.Exact exactFieldMatch = message.getExact();
+                    final ImmutableByteSequence exactValue;
+                    if (isSdnString) {
+                        exactValue = copyFrom(new String(exactFieldMatch.getValue().toByteArray()));
+                    } else {
+                        exactValue = copyAndFit(
+                                exactFieldMatch.getValue().asReadOnlyByteBuffer(),
+                                fieldBitwidth);
+                    }
+                    return new PiExactFieldMatch(headerFieldId, exactValue);
+                case TERNARY:
+                    P4RuntimeOuterClass.FieldMatch.Ternary ternaryFieldMatch = message.getTernary();
+                    ImmutableByteSequence ternaryValue = copyAndFit(
+                            ternaryFieldMatch.getValue().asReadOnlyByteBuffer(),
+                            fieldBitwidth);
+                    ImmutableByteSequence ternaryMask = copyAndFit(
+                            ternaryFieldMatch.getMask().asReadOnlyByteBuffer(),
+                            fieldBitwidth);
+                    return new PiTernaryFieldMatch(headerFieldId, ternaryValue, ternaryMask);
+                case LPM:
+                    P4RuntimeOuterClass.FieldMatch.LPM lpmFieldMatch = message.getLpm();
+                    ImmutableByteSequence lpmValue = copyAndFit(
+                            lpmFieldMatch.getValue().asReadOnlyByteBuffer(),
+                            fieldBitwidth);
+                    int lpmPrefixLen = lpmFieldMatch.getPrefixLen();
+                    return new PiLpmFieldMatch(headerFieldId, lpmValue, lpmPrefixLen);
+                case RANGE:
+                    P4RuntimeOuterClass.FieldMatch.Range rangeFieldMatch = message.getRange();
+                    ImmutableByteSequence rangeHighValue = copyAndFit(
+                            rangeFieldMatch.getHigh().asReadOnlyByteBuffer(),
+                            fieldBitwidth);
+                    ImmutableByteSequence rangeLowValue = copyAndFit(
+                            rangeFieldMatch.getLow().asReadOnlyByteBuffer(),
+                            fieldBitwidth);
+                    return new PiRangeFieldMatch(headerFieldId, rangeLowValue, rangeHighValue);
+                case OPTIONAL:
+                    P4RuntimeOuterClass.FieldMatch.Optional optionalFieldMatch = message.getOptional();
+                    final ImmutableByteSequence optionalValue;
+                    if (isSdnString) {
+                        optionalValue = copyFrom(new String(optionalFieldMatch.getValue().toByteArray()));
+                    } else {
+                        optionalValue = copyAndFit(
+                                optionalFieldMatch.getValue().asReadOnlyByteBuffer(),
+                                fieldBitwidth);
+                    }
+                    return new PiOptionalFieldMatch(headerFieldId, optionalValue);
+                default:
+                    throw new CodecException(format(
+                            "Decoding of field match type '%s' not implemented", typeCase.name()));
+            }
+        } catch (ImmutableByteSequence.ByteSequenceTrimException e) {
+            throw new CodecException(e.getMessage());
         }
     }
 }
diff --git a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/PacketMetadataCodec.java b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/PacketMetadataCodec.java
index 876f84c..ab29e62 100644
--- a/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/PacketMetadataCodec.java
+++ b/protocols/p4runtime/utils/src/main/java/org/onosproject/p4runtime/ctl/codec/PacketMetadataCodec.java
@@ -25,6 +25,7 @@
 import p4.config.v1.P4InfoOuterClass;
 import p4.v1.P4RuntimeOuterClass;
 
+import static org.onlab.util.ImmutableByteSequence.copyAndFit;
 import static org.onlab.util.ImmutableByteSequence.copyFrom;
 
 /**
@@ -54,18 +55,23 @@
             P4RuntimeOuterClass.PacketMetadata message,
             P4InfoOuterClass.Preamble ctrlPktMetaPreamble,
             PiPipeconf pipeconf, P4InfoBrowser browser)
-            throws P4InfoBrowser.NotFoundException {
-        final P4InfoOuterClass.ControllerPacketMetadata.Metadata packetMetadata =
+            throws P4InfoBrowser.NotFoundException, CodecException {
+        final P4InfoOuterClass.ControllerPacketMetadata.Metadata pktMeta =
                 browser.packetMetadatas(ctrlPktMetaPreamble.getId())
-                .getById(message.getMetadataId());
+                        .getById(message.getMetadataId());
         final ImmutableByteSequence value;
-        if (browser.isTypeString(packetMetadata.getTypeName())) {
+        if (browser.isTypeString(pktMeta.getTypeName())) {
             value = copyFrom(new String(message.getValue().toByteArray()));
         } else {
-            value = copyFrom(message.getValue().asReadOnlyByteBuffer());
+            try {
+                value = copyAndFit(message.getValue().asReadOnlyByteBuffer(),
+                                   pktMeta.getBitwidth());
+            } catch (ImmutableByteSequence.ByteSequenceTrimException e) {
+                throw new CodecException(e.getMessage());
+            }
         }
         return PiPacketMetadata.builder()
-                .withId(PiPacketMetadataId.of(packetMetadata.getName()))
+                .withId(PiPacketMetadataId.of(pktMeta.getName()))
                 .withValue(value)
                 .build();
     }
diff --git a/utils/misc/src/main/java/org/onlab/util/ImmutableByteSequence.java b/utils/misc/src/main/java/org/onlab/util/ImmutableByteSequence.java
index 1710d29..32e9d0b 100644
--- a/utils/misc/src/main/java/org/onlab/util/ImmutableByteSequence.java
+++ b/utils/misc/src/main/java/org/onlab/util/ImmutableByteSequence.java
@@ -190,6 +190,53 @@
     }
 
     /**
+     * Creates a new immutable byte sequence while trimming or expanding the
+     * content of the given byte buffer to fit the given bit-width. Calling this
+     * method has the same behavior as
+     * {@code ImmutableByteSequence.copyFrom(original).fit(bitWidth)}.
+     *
+     * @param original a byte buffer value
+     * @param bitWidth a non-zero positive integer
+     * @return a new immutable byte sequence
+     * @throws ByteSequenceTrimException if the byte buffer cannot be fitted
+     */
+    public static ImmutableByteSequence copyAndFit(ByteBuffer original, int bitWidth)
+            throws ByteSequenceTrimException {
+        checkArgument(original != null && original.capacity() > 0,
+                      "Cannot copy from an empty or null byte buffer");
+        checkArgument(bitWidth > 0,
+                      "bit-width must be a non-zero positive integer");
+        if (original.order() == ByteOrder.LITTLE_ENDIAN) {
+            // FIXME: this can be improved, e.g. read bytes in reverse order from original
+            byte[] newBytes = new byte[original.capacity()];
+            original.get(newBytes);
+            reverse(newBytes);
+            return internalCopyAndFit(ByteBuffer.wrap(newBytes), bitWidth);
+        } else {
+            return internalCopyAndFit(original.duplicate(), bitWidth);
+        }
+    }
+
+    private static ImmutableByteSequence internalCopyAndFit(ByteBuffer byteBuf, int bitWidth)
+            throws ByteSequenceTrimException {
+        final int byteWidth = (bitWidth + 7) / 8;
+        final ByteBuffer newByteBuffer = ByteBuffer.allocate(byteWidth);
+        byteBuf.rewind();
+        if (byteWidth >= byteBuf.capacity()) {
+            newByteBuffer.position(byteWidth - byteBuf.capacity());
+            newByteBuffer.put(byteBuf);
+        } else {
+            for (int i = 0; i < byteBuf.capacity() - byteWidth; i++) {
+                if (byteBuf.get(i) != 0) {
+                    throw new ByteSequenceTrimException(byteBuf, bitWidth);
+                }
+            }
+            newByteBuffer.put(byteBuf.position(byteBuf.capacity() - byteWidth));
+        }
+        return new ImmutableByteSequence(newByteBuffer);
+    }
+
+    /**
      * Creates a new byte sequence of the given size where all bits are 0.
      *
      * @param size number of bytes
@@ -500,12 +547,17 @@
     }
 
     /**
-     * Signals that a byte sequence cannot be trimmed.
+     * Signals a trim exception during byte sequence creation.
      */
     public static class ByteSequenceTrimException extends Exception {
         ByteSequenceTrimException(ImmutableByteSequence original, int bitWidth) {
             super(format("cannot trim %s into a %d bits long value",
                          original, bitWidth));
         }
+
+        ByteSequenceTrimException(ByteBuffer original, int bitWidth) {
+            super(format("cannot trim %s (ByteBuffer) into a %d bits long value",
+                         original, bitWidth));
+        }
     }
 }
diff --git a/utils/misc/src/test/java/org/onlab/util/ImmutableByteSequenceTest.java b/utils/misc/src/test/java/org/onlab/util/ImmutableByteSequenceTest.java
index 7e74328..c83c6b2 100644
--- a/utils/misc/src/test/java/org/onlab/util/ImmutableByteSequenceTest.java
+++ b/utils/misc/src/test/java/org/onlab/util/ImmutableByteSequenceTest.java
@@ -25,6 +25,7 @@
 
 import java.nio.ByteBuffer;
 import java.nio.ByteOrder;
+import java.util.Arrays;
 import java.util.Random;
 
 import static java.lang.Integer.max;
@@ -81,6 +82,107 @@
     }
 
     @Test
+    public void testCopyAndFit() throws Exception {
+        int originalByteWidth = 3;
+        int paddedByteWidth = 4;
+        int trimmedByteWidth = 2;
+        int indexFirstNonZeroByte = 1;
+
+        byte byteValue = (byte) 1;
+        byte[] arrayValue = new byte[originalByteWidth];
+        arrayValue[indexFirstNonZeroByte] = byteValue;
+        ByteBuffer bufferValue = ByteBuffer.allocate(originalByteWidth).put(arrayValue);
+
+        ImmutableByteSequence bsBuffer = ImmutableByteSequence.copyAndFit(
+                bufferValue, originalByteWidth * 8);
+        ImmutableByteSequence bsBufferTrimmed = ImmutableByteSequence.copyAndFit(
+                bufferValue, trimmedByteWidth * 8);
+        ImmutableByteSequence bsBufferPadded = ImmutableByteSequence.copyAndFit(
+                bufferValue, paddedByteWidth * 8);
+
+        assertThat("byte sequence of the byte buffer must be 3 bytes long",
+                   bsBuffer.size(), is(equalTo(originalByteWidth)));
+        assertThat("byte sequence of the byte buffer must be 3 bytes long",
+                   bsBufferTrimmed.size(), is(equalTo(trimmedByteWidth)));
+        assertThat("byte sequence of the byte buffer must be 3 bytes long",
+                   bsBufferPadded.size(), is(equalTo(paddedByteWidth)));
+
+        String errStr = "incorrect byte sequence value";
+
+        assertThat(errStr, bsBuffer.asArray()[indexFirstNonZeroByte], is(equalTo(byteValue)));
+        assertThat(errStr, bsBufferTrimmed.asArray()[indexFirstNonZeroByte - 1], is(equalTo(byteValue)));
+        assertThat(errStr, bsBufferPadded.asArray()[indexFirstNonZeroByte + 1], is(equalTo(byteValue)));
+        assertThat(errStr, bsBufferPadded.asArray()[paddedByteWidth - 1], is(equalTo((byte) 0x00)));
+    }
+
+    @Test
+    public void testCopyAndFitEndianness() throws Exception {
+        int originalByteWidth = 4;
+        int indexByteNonZeroBig = 1;
+        int indexByteNonZeroLittle = 2;
+        byte byteValue = (byte) 1;
+
+        ByteBuffer bbBigEndian = ByteBuffer
+                .allocate(originalByteWidth)
+                .order(ByteOrder.BIG_ENDIAN);
+        bbBigEndian.put(indexByteNonZeroBig, byteValue);
+        ImmutableByteSequence bsBufferCopyBigEndian =
+                ImmutableByteSequence.copyAndFit(bbBigEndian, originalByteWidth * 8);
+
+        ByteBuffer bbLittleEndian = ByteBuffer
+                .allocate(originalByteWidth)
+                .order(ByteOrder.LITTLE_ENDIAN);
+        bbLittleEndian.put(indexByteNonZeroLittle, byteValue);
+        ImmutableByteSequence bsBufferCopyLittleEndian =
+                ImmutableByteSequence.copyAndFit(bbLittleEndian, originalByteWidth * 8);
+
+        // creates a new sequence from primitive type
+        byte[] arrayValue = new byte[originalByteWidth];
+        arrayValue[indexByteNonZeroBig] = byteValue;
+        ImmutableByteSequence bsArrayCopy =
+                ImmutableByteSequence.copyFrom(arrayValue);
+
+        new EqualsTester()
+                // big-endian byte array cannot be equal to little-endian array
+                .addEqualityGroup(bbBigEndian.array())
+                .addEqualityGroup(bbLittleEndian.array())
+                // all byte sequences must be equal
+                .addEqualityGroup(bsBufferCopyBigEndian,
+                                  bsBufferCopyLittleEndian,
+                                  bsArrayCopy)
+                // byte buffer views of all sequences must be equal
+                .addEqualityGroup(bsBufferCopyBigEndian.asReadOnlyBuffer(),
+                                  bsBufferCopyLittleEndian.asReadOnlyBuffer(),
+                                  bsArrayCopy.asReadOnlyBuffer())
+                // byte buffer orders of all sequences must be ByteOrder.BIG_ENDIAN
+                .addEqualityGroup(bsBufferCopyBigEndian.asReadOnlyBuffer().order(),
+                                  bsBufferCopyLittleEndian.asReadOnlyBuffer().order(),
+                                  bsArrayCopy.asReadOnlyBuffer().order(),
+                                  ByteOrder.BIG_ENDIAN)
+                .testEquals();
+    }
+
+    @Test
+    public void testIllegalCopyAndFit() throws Exception {
+        int originalByteWidth = 3;
+        int trimmedByteWidth = 1;
+        int indexFirstNonZeroByte = 1;
+
+        byte byteValue = (byte) 1;
+        byte[] arrayValue = new byte[originalByteWidth];
+        arrayValue[indexFirstNonZeroByte] = byteValue;
+        ByteBuffer bufferValue = ByteBuffer.allocate(originalByteWidth).put(arrayValue);
+
+        try {
+            ImmutableByteSequence.copyAndFit(bufferValue, trimmedByteWidth * 8);
+            Assert.fail(format("Expect ByteSequenceTrimException due to value = %s and bitWidth %d",
+                               Arrays.toString(arrayValue), trimmedByteWidth * 8));
+        } catch (ImmutableByteSequence.ByteSequenceTrimException e) {
+            // We expect this.
+        }
+    }
+
+    @Test
     public void testEndianness() throws Exception {
 
         long longValue = RandomUtils.nextLong();