[ONOS-6075] Rewrite Copycat Transport
- Ensure connection IDs are globally unique
- Ensure connections are closed on each side when close() is called
- Add Transport unit tests

Change-Id: Ia848b075d4030ce74293ecc57fea983693cee265
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
index 17fa742..45cc529 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
@@ -15,73 +15,59 @@
  */
 package org.onosproject.store.primitives.impl;
 
-import static com.google.common.base.Preconditions.checkNotNull;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Maps;
+import io.atomix.catalyst.transport.Address;
+import io.atomix.catalyst.transport.Client;
+import io.atomix.catalyst.transport.Server;
+import io.atomix.catalyst.transport.Transport;
+import org.onlab.packet.IpAddress;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
 
 import java.net.InetAddress;
 import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.Map;
 
-import org.onlab.packet.IpAddress;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.Endpoint;
-import org.onosproject.store.cluster.messaging.MessagingService;
-
-import com.google.common.base.Throwables;
-import com.google.common.collect.Maps;
-
-import io.atomix.catalyst.transport.Address;
-import io.atomix.catalyst.transport.Client;
-import io.atomix.catalyst.transport.Server;
-import io.atomix.catalyst.transport.Transport;
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
 
 /**
- * Custom {@link Transport transport} for Copycat interactions
- * built on top of {@link MessagingService}.
- *
- * @see CopycatTransportServer
- * @see CopycatTransportClient
+ * Copycat transport implementation built on {@link MessagingService}.
  */
 public class CopycatTransport implements Transport {
-
-    /**
-     * Transport Mode.
-     */
-    public enum Mode {
-        /**
-         * Signifies transport for client {@literal ->} server interaction.
-         */
-        CLIENT,
-
-        /**
-         * Signified transport for server {@literal ->} server interaction.
-         */
-        SERVER
-    }
-
-    private final Mode mode;
     private final PartitionId partitionId;
     private final MessagingService messagingService;
     private static final Map<Address, Endpoint> EP_LOOKUP_CACHE = Maps.newConcurrentMap();
     private static final Map<Endpoint, Address> ADDRESS_LOOKUP_CACHE = Maps.newConcurrentMap();
 
-    public CopycatTransport(Mode mode, PartitionId partitionId, MessagingService messagingService) {
-        this.mode = checkNotNull(mode);
-        this.partitionId = checkNotNull(partitionId);
-        this.messagingService = checkNotNull(messagingService);
+    static final byte MESSAGE = 0x01;
+    static final byte CONNECT = 0x02;
+    static final byte CLOSE = 0x03;
+
+    static final byte SUCCESS = 0x01;
+    static final byte FAILURE = 0x02;
+
+    public CopycatTransport(PartitionId partitionId, MessagingService messagingService) {
+        this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+        this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
     }
 
     @Override
     public Client client() {
-        return new CopycatTransportClient(partitionId,
-                                          messagingService,
-                                          mode);
+        return new CopycatTransportClient(partitionId, messagingService);
     }
 
     @Override
     public Server server() {
-        return new CopycatTransportServer(partitionId,
-                                          messagingService);
+        return new CopycatTransportServer(partitionId, messagingService);
+    }
+
+    @Override
+    public String toString() {
+        return toStringHelper(this).toString();
     }
 
     /**
@@ -89,7 +75,7 @@
      * @param address address
      * @return end point
      */
-    public static Endpoint toEndpoint(Address address) {
+    static Endpoint toEndpoint(Address address) {
         return EP_LOOKUP_CACHE.computeIfAbsent(address, a -> {
             try {
                 return new Endpoint(IpAddress.valueOf(InetAddress.getByName(a.host())), a.port());
@@ -105,7 +91,7 @@
      * @param endpoint end point
      * @return address
      */
-    public static Address toAddress(Endpoint endpoint) {
+    static Address toAddress(Endpoint endpoint) {
         return ADDRESS_LOOKUP_CACHE.computeIfAbsent(endpoint, ep -> {
             try {
                 InetAddress host = InetAddress.getByAddress(endpoint.host().toOctets());
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
index 7a1cecf..3567338 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
@@ -15,53 +15,94 @@
  */
 package org.onosproject.store.primitives.impl;
 
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-
-import org.apache.commons.lang.math.RandomUtils;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.MessagingService;
-
+import com.google.common.base.Throwables;
 import com.google.common.collect.Sets;
-
+import io.atomix.catalyst.concurrent.ThreadContext;
 import io.atomix.catalyst.transport.Address;
 import io.atomix.catalyst.transport.Client;
 import io.atomix.catalyst.transport.Connection;
-import io.atomix.catalyst.concurrent.ThreadContext;
+import io.atomix.catalyst.transport.TransportException;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingException;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.ConnectException;
+import java.nio.ByteBuffer;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CONNECT;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
 
 /**
- * {@link Client} implementation for {@link CopycatTransport}.
+ * Copycat transport client implementation.
  */
 public class CopycatTransportClient implements Client {
-
+    private final Logger log = LoggerFactory.getLogger(getClass());
     private final PartitionId partitionId;
+    private final String serverSubject;
     private final MessagingService messagingService;
-    private final CopycatTransport.Mode mode;
     private final Set<CopycatTransportConnection> connections = Sets.newConcurrentHashSet();
 
-    CopycatTransportClient(PartitionId partitionId, MessagingService messagingService, CopycatTransport.Mode mode) {
-        this.partitionId = checkNotNull(partitionId);
-        this.messagingService = checkNotNull(messagingService);
-        this.mode = checkNotNull(mode);
+    public CopycatTransportClient(PartitionId partitionId, MessagingService messagingService) {
+        this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+        this.serverSubject = String.format("onos-copycat-%s", partitionId);
+        this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
     }
 
     @Override
-    public CompletableFuture<Connection> connect(Address remoteAddress) {
+    public CompletableFuture<Connection> connect(Address address) {
+        CompletableFuture<Connection> future = new CompletableFuture<>();
         ThreadContext context = ThreadContext.currentContextOrThrow();
-        CopycatTransportConnection connection = new CopycatTransportConnection(
-                nextConnectionId(),
-                CopycatTransport.Mode.CLIENT,
-                partitionId,
-                remoteAddress,
-                messagingService,
-                context);
-        if (mode == CopycatTransport.Mode.CLIENT) {
-            connection.setBidirectional();
-        }
-        connections.add(connection);
-        return CompletableFuture.supplyAsync(() -> connection, context.executor());
+
+        Endpoint endpoint = CopycatTransport.toEndpoint(address);
+
+        log.debug("Connecting to {}", address);
+
+        ByteBuffer requestBuffer = ByteBuffer.allocate(1);
+        requestBuffer.put(CONNECT);
+
+        // Send a connect request to the server to get a unique connection ID.
+        messagingService.sendAndReceive(endpoint, serverSubject, requestBuffer.array(), context.executor())
+                .whenComplete((payload, error) -> {
+                    Throwable wrappedError = error;
+                    if (error != null) {
+                        Throwable rootCause = Throwables.getRootCause(error);
+                        if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
+                            wrappedError = new TransportException(error);
+                        }
+                        log.warn("Connection to {} failed! Reason: {}", address, wrappedError);
+                        future.completeExceptionally(wrappedError);
+                    } else {
+                        // If the connection is successful, the server will send back a
+                        // connection ID indicating where to send messages for the connection.
+                        ByteBuffer responseBuffer = ByteBuffer.wrap(payload);
+                        if (responseBuffer.get() == SUCCESS) {
+                            long connectionId = responseBuffer.getLong();
+                            CopycatTransportConnection connection = new CopycatTransportConnection(
+                                    connectionId,
+                                    CopycatTransportConnection.Mode.CLIENT,
+                                    partitionId,
+                                    endpoint,
+                                    messagingService,
+                                    context);
+                            connection.closeListener(connections::remove);
+                            connections.add(connection);
+                            future.complete(connection);
+                            log.debug("Created connection {}-{} to {}", partitionId, connectionId, address);
+                        } else {
+                            log.warn("Connection to {} failed!");
+                            future.completeExceptionally(new ConnectException());
+                        }
+                    }
+                });
+        return future;
+
     }
 
     @Override
@@ -69,7 +110,11 @@
         return CompletableFuture.allOf(connections.stream().map(Connection::close).toArray(CompletableFuture[]::new));
     }
 
-    private long nextConnectionId() {
-        return RandomUtils.nextLong();
+    @Override
+    public String toString() {
+        return toStringHelper(this)
+                .add("partitionId", partitionId)
+                .toString();
     }
- }
+}
+
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
index f2752cd..fadb6dd 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
@@ -15,24 +15,15 @@
  */
 package org.onosproject.store.primitives.impl;
 
-import com.google.common.base.MoreObjects;
-import com.google.common.base.Throwables;
-import com.google.common.collect.Maps;
+
 import io.atomix.catalyst.concurrent.Listener;
 import io.atomix.catalyst.concurrent.Listeners;
 import io.atomix.catalyst.concurrent.ThreadContext;
 import io.atomix.catalyst.serializer.SerializationException;
-import io.atomix.catalyst.transport.Address;
 import io.atomix.catalyst.transport.Connection;
 import io.atomix.catalyst.transport.MessageHandler;
 import io.atomix.catalyst.transport.TransportException;
-import io.atomix.catalyst.util.Assert;
 import io.atomix.catalyst.util.reference.ReferenceCounted;
-import org.apache.commons.io.IOUtils;
-import org.onlab.util.Tools;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.MessagingException;
-import org.onosproject.store.cluster.messaging.MessagingService;
 
 import java.io.ByteArrayInputStream;
 import java.io.ByteArrayOutputStream;
@@ -40,79 +31,76 @@
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.nio.ByteBuffer;
 import java.util.Map;
-import java.util.Objects;
 import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.Consumer;
 
+import org.apache.commons.io.IOUtils;
+import org.onlab.util.Tools;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingException;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Throwables;
+
 import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CLOSE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.MESSAGE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
+
 
 /**
- * {@link Connection} implementation for CopycatTransport.
+ * Base Copycat Transport connection.
  */
 public class CopycatTransportConnection implements Connection {
-
+    private final Logger log = LoggerFactory.getLogger(getClass());
+    private final long connectionId;
+    private final String localSubject;
+    private final String remoteSubject;
+    private final PartitionId partitionId;
+    private final Endpoint endpoint;
+    private final MessagingService messagingService;
+    private final ThreadContext context;
+    private final Map<Class, InternalHandler> handlers = new ConcurrentHashMap<>();
     private final Listeners<Throwable> exceptionListeners = new Listeners<>();
     private final Listeners<Connection> closeListeners = new Listeners<>();
 
-    static final byte SUCCESS = 0x03;
-    static final byte FAILURE = 0x04;
-
-    private final long connectionId;
-    private final CopycatTransport.Mode mode;
-    private final Address remoteAddress;
-    private final MessagingService messagingService;
-    private final String outboundMessageSubject;
-    private final String inboundMessageSubject;
-    private final ThreadContext context;
-    private final Map<Class<?>, InternalHandler> handlers = Maps.newConcurrentMap();
-
-    CopycatTransportConnection(long connectionId,
-            CopycatTransport.Mode mode,
+    CopycatTransportConnection(
+            long connectionId,
+            Mode mode,
             PartitionId partitionId,
-            Address address,
+            Endpoint endpoint,
             MessagingService messagingService,
             ThreadContext context) {
         this.connectionId = connectionId;
-        this.mode = checkNotNull(mode);
-        this.remoteAddress = checkNotNull(address);
-        this.messagingService = checkNotNull(messagingService);
-        if (mode == CopycatTransport.Mode.CLIENT) {
-            this.outboundMessageSubject = String.format("onos-copycat-%s", partitionId);
-            this.inboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
-        } else {
-            this.outboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
-            this.inboundMessageSubject = String.format("onos-copycat-%s", partitionId);
-        }
-        this.context = checkNotNull(context);
-    }
-
-    public void setBidirectional() {
-        messagingService.registerHandler(inboundMessageSubject, (sender, payload) -> {
-            try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
-                if (input.readLong() !=  connectionId) {
-                    throw new IllegalStateException("Invalid connection Id");
-                }
-                return handle(IOUtils.toByteArray(input));
-            } catch (IOException e) {
-                Throwables.propagate(e);
-                return null;
-            }
-        });
+        this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+        this.localSubject = mode.getLocalSubject(partitionId, connectionId);
+        this.remoteSubject = mode.getRemoteSubject(partitionId, connectionId);
+        this.endpoint = checkNotNull(endpoint, "endpoint cannot be null");
+        this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
+        this.context = checkNotNull(context, "context cannot be null");
+        messagingService.registerHandler(localSubject, this::handle);
     }
 
     @Override
     public <T, U> CompletableFuture<U> send(T message) {
         ThreadContext context = ThreadContext.currentContextOrThrow();
-        CompletableFuture<U> result = new CompletableFuture<>();
+        CompletableFuture<U> future = new CompletableFuture<>();
         try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
-            new DataOutputStream(baos).writeLong(connectionId);
+            DataOutputStream dos = new DataOutputStream(baos);
+            dos.writeByte(MESSAGE);
             context.serializer().writeObject(message, baos);
             if (message instanceof ReferenceCounted) {
                 ((ReferenceCounted<?>) message).release();
             }
-            messagingService.sendAndReceive(CopycatTransport.toEndpoint(remoteAddress),
-                                            outboundMessageSubject,
+            messagingService.sendAndReceive(endpoint,
+                                            remoteSubject,
                                             baos.toByteArray(),
                                             context.executor())
                     .whenComplete((r, e) -> {
@@ -123,20 +111,23 @@
                                 wrappedError = new TransportException(e);
                             }
                         }
-                        handleResponse(r, wrappedError, result, context);
+                        handleResponse(r, wrappedError, future);
                     });
         } catch (SerializationException | IOException e) {
-            result.completeExceptionally(e);
+            future.completeExceptionally(e);
         }
-        return result;
+        return future;
     }
 
-    private <T> void handleResponse(byte[] response,
-                                    Throwable error,
-                                    CompletableFuture<T> future,
-                                    ThreadContext context) {
+    /**
+     * Handles a response received from the other side of the connection.
+     */
+    private <T> void handleResponse(
+            byte[] response,
+            Throwable error,
+            CompletableFuture<T> future) {
         if (error != null) {
-            context.execute(() -> future.completeExceptionally(error));
+            future.completeExceptionally(error);
             return;
         }
         checkNotNull(response);
@@ -145,36 +136,54 @@
             byte status = (byte) input.read();
             if (status == FAILURE) {
                 Throwable t = context.serializer().readObject(input);
-                context.execute(() -> future.completeExceptionally(t));
+                future.completeExceptionally(t);
             } else {
-                context.execute(() -> {
-                    try {
-                        future.complete(context.serializer().readObject(input));
-                    } catch (SerializationException e) {
-                        future.completeExceptionally(e);
-                    }
-                });
+                try {
+                    future.complete(context.serializer().readObject(input));
+                } catch (SerializationException e) {
+                    future.completeExceptionally(e);
+                }
             }
         } catch (IOException e) {
-            context.execute(() -> future.completeExceptionally(e));
+            future.completeExceptionally(e);
         }
     }
 
-    @Override
-    public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
-        Assert.notNull(type, "type");
-        handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
-        return null;
+    /**
+     * Handles a message sent to the connection.
+     */
+    private CompletableFuture<byte[]> handle(Endpoint sender, byte[] payload) {
+        try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
+            byte type = input.readByte();
+            switch (type) {
+                case MESSAGE:
+                    return handleMessage(IOUtils.toByteArray(input));
+                case CLOSE:
+                    return handleClose();
+                default:
+                    throw new IllegalStateException("Invalid message type");
+            }
+        } catch (IOException e) {
+            Throwables.propagate(e);
+            return null;
+        }
     }
 
-   public CompletableFuture<byte[]> handle(byte[] message) {
+    /**
+     * Handles a message from the other side of the connection.
+     */
+    @SuppressWarnings("unchecked")
+    private CompletableFuture<byte[]> handleMessage(byte[] message) {
         try {
             Object request = context.serializer().readObject(new ByteArrayInputStream(message));
             InternalHandler handler = handlers.get(request.getClass());
             if (handler == null) {
+                log.warn("No handler registered on connection {}-{} for type {}",
+                         partitionId, connectionId, request.getClass());
                 return Tools.exceptionalFuture(new IllegalStateException(
                         "No handler registered for " + request.getClass()));
             }
+
             return handler.handle(request).handle((result, error) -> {
                 try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
                     baos.write(error != null ? FAILURE : SUCCESS);
@@ -190,67 +199,159 @@
         }
     }
 
-    @Override
-    public Listener<Throwable> exceptionListener(Consumer<Throwable> listener) {
-        return exceptionListeners.add(listener);
+    /**
+     * Handles a close request from the other side of the connection.
+     */
+    private CompletableFuture<byte[]> handleClose() {
+        CompletableFuture<byte[]> future = new CompletableFuture<>();
+        context.executor().execute(() -> {
+            cleanup();
+            ByteBuffer responseBuffer = ByteBuffer.allocate(1);
+            responseBuffer.put(SUCCESS);
+            future.complete(responseBuffer.array());
+        });
+        return future;
     }
 
     @Override
-    public Listener<Connection> closeListener(Consumer<Connection> listener) {
-        return closeListeners.add(listener);
+    public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
+        if (log.isTraceEnabled()) {
+            log.trace("Registered handler on connection {}-{}: {}", partitionId, connectionId, type);
+        }
+        handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
+        return this;
+    }
+
+    @Override
+    public Listener<Throwable> exceptionListener(Consumer<Throwable> consumer) {
+        return exceptionListeners.add(consumer);
+    }
+
+    @Override
+    public Listener<Connection> closeListener(Consumer<Connection> consumer) {
+        return closeListeners.add(consumer);
     }
 
     @Override
     public CompletableFuture<Void> close() {
-        closeListeners.forEach(listener -> listener.accept(this));
-        if (mode == CopycatTransport.Mode.CLIENT) {
-            messagingService.unregisterHandler(inboundMessageSubject);
-        }
-        return CompletableFuture.completedFuture(null);
+        log.debug("Closing connection {}-{}", partitionId, connectionId);
+
+        ByteBuffer requestBuffer = ByteBuffer.allocate(1);
+        requestBuffer.put(CLOSE);
+
+        ThreadContext context = ThreadContext.currentContextOrThrow();
+        CompletableFuture<Void> future = new CompletableFuture<>();
+        messagingService.sendAndReceive(endpoint, remoteSubject, requestBuffer.array(), context.executor())
+                .whenComplete((payload, error) -> {
+                    cleanup();
+                    Throwable wrappedError = error;
+                    if (error != null) {
+                        Throwable rootCause = Throwables.getRootCause(error);
+                        if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
+                            wrappedError = new TransportException(error);
+                        }
+                        future.completeExceptionally(wrappedError);
+                    } else {
+                        ByteBuffer responseBuffer = ByteBuffer.wrap(payload);
+                        if (responseBuffer.get() == SUCCESS) {
+                            future.complete(null);
+                        } else {
+                            future.completeExceptionally(new TransportException("Failed to close connection"));
+                        }
+                    }
+                });
+        return future;
     }
 
-    @Override
-    public int hashCode() {
-        return Objects.hash(connectionId);
+    /**
+     * Cleans up the connection, unregistering handlers registered on the MessagingService.
+     */
+    private void cleanup() {
+        log.debug("Connection {}-{} closed", partitionId, connectionId);
+        messagingService.unregisterHandler(localSubject);
+        closeListeners.accept(this);
     }
 
-    @Override
-    public boolean equals(Object other) {
-        if (!(other instanceof CopycatTransportConnection)) {
-            return false;
-        }
-        return connectionId == ((CopycatTransportConnection) other).connectionId;
+    /**
+     * Connection mode used to indicate whether this side of the connection is
+     * a client or server.
+     */
+    enum Mode {
+
+        /**
+         * Represents the client side of a bi-directional connection.
+         */
+        CLIENT {
+            @Override
+            String getLocalSubject(PartitionId partitionId, long connectionId) {
+                return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+            }
+
+            @Override
+            String getRemoteSubject(PartitionId partitionId, long connectionId) {
+                return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+            }
+        },
+
+        /**
+         * Represents the server side of a bi-directional connection.
+         */
+        SERVER {
+            @Override
+            String getLocalSubject(PartitionId partitionId, long connectionId) {
+                return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+            }
+
+            @Override
+            String getRemoteSubject(PartitionId partitionId, long connectionId) {
+                return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+            }
+        };
+
+        /**
+         * Returns the local messaging service subject for the connection in this mode.
+         * Subjects generated by the connection mode are guaranteed to be globally unique.
+         *
+         * @param partitionId the partition ID to which the connection belongs.
+         * @param connectionId the connection ID.
+         * @return the globally unique local subject for the connection.
+         */
+        abstract String getLocalSubject(PartitionId partitionId, long connectionId);
+
+        /**
+         * Returns the remote messaging service subject for the connection in this mode.
+         * Subjects generated by the connection mode are guaranteed to be globally unique.
+         *
+         * @param partitionId the partition ID to which the connection belongs.
+         * @param connectionId the connection ID.
+         * @return the globally unique remote subject for the connection.
+         */
+        abstract String getRemoteSubject(PartitionId partitionId, long connectionId);
     }
 
-    @Override
-    public String toString() {
-        return MoreObjects.toStringHelper(getClass())
-                .add("id", connectionId)
-                .toString();
-    }
-
-    @SuppressWarnings("rawtypes")
-    private final class InternalHandler {
-
+    /**
+     * Internal container for a handler/context pair.
+     */
+    private static class InternalHandler {
         private final MessageHandler handler;
         private final ThreadContext context;
 
-        private InternalHandler(MessageHandler handler, ThreadContext context) {
+        InternalHandler(MessageHandler handler, ThreadContext context) {
             this.handler = handler;
             this.context = context;
         }
 
         @SuppressWarnings("unchecked")
-        public CompletableFuture<Object> handle(Object message) {
-            CompletableFuture<Object> answer = new CompletableFuture<>();
+        CompletableFuture<Object> handle(Object message) {
+            CompletableFuture<Object> future = new CompletableFuture<>();
             context.execute(() -> handler.handle(message).whenComplete((r, e) -> {
                 if (e != null) {
-                    answer.completeExceptionally((Throwable) e);
+                    future.completeExceptionally((Throwable) e);
                 } else {
-                    answer.complete(r);
+                    future.complete(r);
                 }
             }));
-            return answer;
+            return future;
         }
     }
 }
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
index ff64112..aee7647 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
@@ -15,113 +15,99 @@
  */
 package org.onosproject.store.primitives.impl;
 
-import static com.google.common.base.Preconditions.checkNotNull;
-import static org.slf4j.LoggerFactory.getLogger;
-import io.atomix.catalyst.concurrent.CatalystThreadFactory;
-import io.atomix.catalyst.concurrent.SingleThreadContext;
+import com.google.common.collect.Sets;
 import io.atomix.catalyst.concurrent.ThreadContext;
 import io.atomix.catalyst.transport.Address;
 import io.atomix.catalyst.transport.Connection;
 import io.atomix.catalyst.transport.Server;
-
-import java.io.ByteArrayInputStream;
-import java.io.DataInputStream;
-import java.io.IOException;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Consumer;
-
-import org.apache.commons.io.IOUtils;
-import org.onlab.util.Tools;
+import org.apache.commons.lang3.RandomUtils;
 import org.onosproject.cluster.PartitionId;
 import org.onosproject.store.cluster.messaging.MessagingService;
 import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
-import com.google.common.collect.Maps;
+import java.nio.ByteBuffer;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CONNECT;
+import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
 
 /**
- * {@link Server} implementation for {@link CopycatTransport}.
+ * Copycat transport server implementation.
  */
 public class CopycatTransportServer implements Server {
-
-    private final Logger log = getLogger(getClass());
-    private final AtomicBoolean listening = new AtomicBoolean(false);
-    private CompletableFuture<Void> listenFuture = new CompletableFuture<>();
-    private final ScheduledExecutorService executorService;
+    private final Logger log = LoggerFactory.getLogger(getClass());
     private final PartitionId partitionId;
+    private final String serverSubject;
     private final MessagingService messagingService;
-    private final String messageSubject;
-    private final Map<Long, CopycatTransportConnection> connections = Maps.newConcurrentMap();
+    private final Set<CopycatTransportConnection> connections = Sets.newConcurrentHashSet();
 
-    CopycatTransportServer(PartitionId partitionId, MessagingService messagingService) {
-        this.partitionId = checkNotNull(partitionId);
-        this.messagingService = checkNotNull(messagingService);
-        this.messageSubject = String.format("onos-copycat-%s", partitionId);
-        this.executorService = Executors.newScheduledThreadPool(Math.min(4, Runtime.getRuntime().availableProcessors()),
-                new CatalystThreadFactory("copycat-server-p" + partitionId + "-%d"));
+    public CopycatTransportServer(PartitionId partitionId, MessagingService messagingService) {
+        this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+        this.serverSubject = String.format("onos-copycat-%s", partitionId);
+        this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
     }
 
     @Override
-    public CompletableFuture<Void> listen(Address address, Consumer<Connection> listener) {
-        if (listening.compareAndSet(false, true)) {
-            ThreadContext context = ThreadContext.currentContextOrThrow();
-            listen(address, listener, context);
-        }
-        return listenFuture;
-    }
+    public CompletableFuture<Void> listen(Address address, Consumer<Connection> consumer) {
+        ThreadContext context = ThreadContext.currentContextOrThrow();
+        messagingService.registerHandler(serverSubject, (sender, payload) -> {
 
-    private void listen(Address address, Consumer<Connection> listener, ThreadContext context) {
-        messagingService.registerHandler(messageSubject, (sender, payload) -> {
-            try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
-                long connectionId = input.readLong();
-                AtomicBoolean newConnectionCreated = new AtomicBoolean(false);
-                CopycatTransportConnection connection = connections.computeIfAbsent(connectionId, k -> {
-                    newConnectionCreated.set(true);
-                    CopycatTransportConnection newConnection = new CopycatTransportConnection(connectionId,
-                            CopycatTransport.Mode.SERVER,
-                            partitionId,
-                            CopycatTransport.toAddress(sender),
-                            messagingService,
-                            getOrCreateContext(context));
-                    log.debug("Created new incoming connection {}", connectionId);
-                    newConnection.closeListener(c -> connections.remove(connectionId, c));
-                    return newConnection;
-                });
-                byte[] request = IOUtils.toByteArray(input);
-                return CompletableFuture.supplyAsync(
-                        () -> {
-                            if (newConnectionCreated.get()) {
-                                listener.accept(connection);
-                            }
-                            return connection;
-                        }, context.executor()).thenCompose(c -> c.handle(request));
-            } catch (IOException e) {
-                return Tools.exceptionalFuture(e);
+            // Only connect messages can be sent to the server. Once a connect message
+            // is received, the connection will register a separate handler for messaging.
+            ByteBuffer requestBuffer = ByteBuffer.wrap(payload);
+            if (requestBuffer.get() != CONNECT) {
+                ByteBuffer responseBuffer = ByteBuffer.allocate(1);
+                responseBuffer.put(FAILURE);
+                return CompletableFuture.completedFuture(responseBuffer.array());
             }
+
+            // Create the connection and ensure state is cleaned up when the connection is closed.
+            long connectionId = RandomUtils.nextLong(0L, Long.MAX_VALUE);
+            CopycatTransportConnection connection = new CopycatTransportConnection(
+                    connectionId,
+                    CopycatTransportConnection.Mode.SERVER,
+                    partitionId,
+                    sender,
+                    messagingService,
+                    context);
+            connection.closeListener(connections::remove);
+            connections.add(connection);
+
+            CompletableFuture<byte[]> future = new CompletableFuture<>();
+
+            // We need to ensure the connection event is called on the Copycat thread
+            // and that the future is not completed until the Copycat server has been
+            // able to register message handlers, otherwise some messages can be received
+            // prior to any handlers being registered.
+            context.executor().execute(() -> {
+                log.debug("Created connection {}-{}", partitionId, connectionId);
+                consumer.accept(connection);
+
+                ByteBuffer responseBuffer = ByteBuffer.allocate(9);
+                responseBuffer.put(SUCCESS);
+                responseBuffer.putLong(connectionId);
+                future.complete(responseBuffer.array());
+            });
+            return future;
         });
-        context.execute(() -> {
-            listenFuture.complete(null);
-        });
+        return CompletableFuture.completedFuture(null);
     }
 
     @Override
     public CompletableFuture<Void> close() {
-        messagingService.unregisterHandler(messageSubject);
-        executorService.shutdown();
-        return CompletableFuture.completedFuture(null);
+        return CompletableFuture.allOf(connections.stream().map(Connection::close).toArray(CompletableFuture[]::new));
     }
 
-    /**
-     * Returns the current execution context or creates one.
-     */
-    private ThreadContext getOrCreateContext(ThreadContext parentContext) {
-        ThreadContext context = ThreadContext.currentContext();
-        if (context != null) {
-            return context;
-        }
-        return new SingleThreadContext(executorService, parentContext.serializer().clone());
+    @Override
+    public String toString() {
+        return toStringHelper(this)
+                .add("partitionId", partitionId)
+                .toString();
     }
 }
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
index 7ed275a..a68b793 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
@@ -131,9 +131,7 @@
         StoragePartitionServer server = new StoragePartitionServer(toAddress(localNodeId),
                 this,
                 serializer,
-                () -> new CopycatTransport(CopycatTransport.Mode.SERVER,
-                                     partition.getId(),
-                                     messagingService),
+                () -> new CopycatTransport(partition.getId(), messagingService),
                 logFolder);
         return server.open().thenRun(() -> this.server = server);
     }
@@ -150,9 +148,7 @@
         StoragePartitionServer server = new StoragePartitionServer(toAddress(localNodeId),
                 this,
                 serializer,
-                () -> new CopycatTransport(CopycatTransport.Mode.SERVER,
-                                     partition.getId(),
-                                     messagingService),
+                () -> new CopycatTransport(partition.getId(), messagingService),
                 logFolder);
         return server.join(Collections2.transform(otherMembers, this::toAddress)).thenRun(() -> this.server = server);
     }
@@ -160,9 +156,7 @@
     private CompletableFuture<StoragePartitionClient> openClient() {
         client = new StoragePartitionClient(this,
                 serializer,
-                new CopycatTransport(CopycatTransport.Mode.CLIENT,
-                                     partition.getId(),
-                                     messagingService));
+                new CopycatTransport(partition.getId(), messagingService));
         return client.open().thenApply(v -> client);
     }
 
diff --git a/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
new file mode 100644
index 0000000..21b4d70
--- /dev/null
+++ b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
@@ -0,0 +1,360 @@
+/*
+ * Copyright 2017-present Open Networking Laboratory
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.onosproject.store.primitives.impl;
+
+import com.google.common.collect.Lists;
+import io.atomix.catalyst.concurrent.SingleThreadContext;
+import io.atomix.catalyst.concurrent.ThreadContext;
+import io.atomix.catalyst.transport.Address;
+import io.atomix.catalyst.transport.Client;
+import io.atomix.catalyst.transport.Server;
+import io.atomix.catalyst.transport.Transport;
+import io.atomix.copycat.protocol.ConnectRequest;
+import io.atomix.copycat.protocol.ConnectResponse;
+import io.atomix.copycat.protocol.PublishRequest;
+import io.atomix.copycat.protocol.PublishResponse;
+import io.atomix.copycat.protocol.Response;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.onlab.packet.IpAddress;
+import org.onlab.util.Tools;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
+
+import java.time.Duration;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+import static org.onlab.junit.TestTools.findAvailablePort;
+
+/**
+ * Copycat transport test.
+ */
+public class CopycatTransportTest {
+
+    private static final String IP_STRING = "127.0.0.1";
+
+    private Endpoint endpoint1 = new Endpoint(IpAddress.valueOf(IP_STRING), 5001);
+    private Endpoint endpoint2 = new Endpoint(IpAddress.valueOf(IP_STRING), 5002);
+
+    private MessagingService service1;
+    private MessagingService service2;
+
+    private Transport clientTransport;
+    private ThreadContext clientContext;
+
+    private Transport serverTransport;
+    private ThreadContext serverContext;
+
+    @Before
+    public void setUp() throws Exception {
+        Map<Endpoint, TestMessagingService> services = new ConcurrentHashMap<>();
+
+        endpoint1 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5001));
+        service1 = new TestMessagingService(endpoint1, services);
+        clientTransport = new CopycatTransport(PartitionId.from(1), service1);
+        clientContext = new SingleThreadContext("client-test-%d", CatalystSerializers.getSerializer());
+
+        endpoint2 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5003));
+        service2 = new TestMessagingService(endpoint2, services);
+        serverTransport = new CopycatTransport(PartitionId.from(1), service2);
+        serverContext = new SingleThreadContext("server-test-%d", CatalystSerializers.getSerializer());
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        if (clientContext != null) {
+            clientContext.close();
+        }
+        if (serverContext != null) {
+            serverContext.close();
+        }
+    }
+
+    /**
+     * Tests sending a message from the client side of a Copycat connection to the server side.
+     */
+    @Test
+    public void testCopycatClientConnectionSend() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch latch = new CountDownLatch(4);
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        CountDownLatch handlerLatch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+                latch.countDown();
+                connection.handler(ConnectRequest.class, request -> {
+                    serverContext.checkThread();
+                    latch.countDown();
+                    return CompletableFuture.completedFuture(ConnectResponse.builder()
+                            .withStatus(Response.Status.OK)
+                            .withLeader(new Address(IP_STRING, endpoint2.port()))
+                            .withMembers(Lists.newArrayList(new Address(IP_STRING, endpoint2.port())))
+                            .build());
+                });
+                handlerLatch.countDown();
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                latch.countDown();
+                try {
+                    handlerLatch.await(5, TimeUnit.SECONDS);
+                } catch (InterruptedException e) {
+                    fail();
+                }
+                connection.<ConnectRequest, ConnectResponse>send(ConnectRequest.builder()
+                        .withClientId(UUID.randomUUID().toString())
+                        .build())
+                        .thenAccept(response -> {
+                            clientContext.checkThread();
+                            assertNotNull(response);
+                            assertEquals(Response.Status.OK, response.status());
+                            latch.countDown();
+                        });
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
+     * Tests sending a message from the server side of a Copycat connection to the client side.
+     */
+    @Test
+    public void testCopycatServerConnectionSend() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch latch = new CountDownLatch(4);
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+                latch.countDown();
+                serverContext.schedule(Duration.ofMillis(100), () -> {
+                    connection.<PublishRequest, PublishResponse>send(PublishRequest.builder()
+                            .withSession(1)
+                            .withEventIndex(3)
+                            .withPreviousIndex(2)
+                            .build())
+                            .thenAccept(response -> {
+                                serverContext.checkThread();
+                                assertEquals(Response.Status.OK, response.status());
+                                assertEquals(1, response.index());
+                                latch.countDown();
+                            });
+                });
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                latch.countDown();
+                connection.handler(PublishRequest.class, request -> {
+                    clientContext.checkThread();
+                    latch.countDown();
+                    assertEquals(1, request.session());
+                    assertEquals(3, request.eventIndex());
+                    assertEquals(2, request.previousIndex());
+                    return CompletableFuture.completedFuture(PublishResponse.builder()
+                            .withStatus(Response.Status.OK)
+                            .withIndex(1)
+                            .build());
+                });
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
+     * Tests closing the server side of a Copycat connection.
+     */
+    @Test
+    public void testCopycatClientConnectionClose() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch latch = new CountDownLatch(5);
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+                latch.countDown();
+                connection.closeListener(c -> {
+                    serverContext.checkThread();
+                    latch.countDown();
+                });
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                latch.countDown();
+                connection.closeListener(c -> {
+                    clientContext.checkThread();
+                    latch.countDown();
+                });
+                clientContext.schedule(Duration.ofMillis(100), () -> {
+                    connection.close().whenComplete((result, error) -> {
+                        clientContext.checkThread();
+                        latch.countDown();
+                    });
+                });
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
+     * Tests closing the server side of a Copycat connection.
+     */
+    @Test
+    public void testCopycatServerConnectionClose() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch latch = new CountDownLatch(5);
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+                latch.countDown();
+                connection.closeListener(c -> {
+                    latch.countDown();
+                });
+                serverContext.schedule(Duration.ofMillis(100), () -> {
+                    connection.close().whenComplete((result, error) -> {
+                        serverContext.checkThread();
+                        latch.countDown();
+                    });
+                });
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                latch.countDown();
+                connection.closeListener(c -> {
+                    latch.countDown();
+                });
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
+     * Custom implementation of {@code MessagingService} used for testing. Really, this should
+     * be mocked but suffices for now.
+     */
+    public static final class TestMessagingService implements MessagingService {
+        private final Endpoint endpoint;
+        private final Map<Endpoint, TestMessagingService> services;
+        private final Map<String, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>>> handlers =
+                new ConcurrentHashMap<>();
+
+        TestMessagingService(Endpoint endpoint, Map<Endpoint, TestMessagingService> services) {
+            this.endpoint = endpoint;
+            this.services = services;
+            services.put(endpoint, this);
+        }
+
+        private CompletableFuture<byte[]> handle(Endpoint ep, String type, byte[] message, Executor executor) {
+            BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler = handlers.get(type);
+            if (handler == null) {
+                return Tools.exceptionalFuture(new IllegalStateException());
+            }
+            return handler.apply(ep, message).thenApplyAsync(r -> r, executor);
+        }
+
+        @Override
+        public CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload) {
+            // Unused for testing
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
+            // Unused for testing
+            return null;
+        }
+
+        @Override
+        public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
+            TestMessagingService service = services.get(ep);
+            if (service == null) {
+                return Tools.exceptionalFuture(new IllegalStateException());
+            }
+            return service.handle(endpoint, type, payload, executor);
+        }
+
+        @Override
+        public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
+            // Unused for testing
+        }
+
+        @Override
+        public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
+            // Unused for testing
+        }
+
+        @Override
+        public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
+            handlers.put(type, handler);
+        }
+
+        @Override
+        public void unregisterHandler(String type) {
+            handlers.remove(type);
+        }
+    }
+
+}
\ No newline at end of file