[ONOS-6075] Rewrite Copycat Transport
- Ensure connection IDs are globally unique
- Ensure connections are closed on both sides when close() is called
- Add Transport unit tests
Change-Id: Ia848b075d4030ce74293ecc57fea983693cee265
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
index f2752cd..9d7edd0 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
@@ -15,24 +15,23 @@
*/
package org.onosproject.store.primitives.impl;
-import com.google.common.base.MoreObjects;
import com.google.common.base.Throwables;
-import com.google.common.collect.Maps;
import io.atomix.catalyst.concurrent.Listener;
import io.atomix.catalyst.concurrent.Listeners;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.serializer.SerializationException;
-import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Connection;
import io.atomix.catalyst.transport.MessageHandler;
import io.atomix.catalyst.transport.TransportException;
-import io.atomix.catalyst.util.Assert;
import io.atomix.catalyst.util.reference.ReferenceCounted;
import org.apache.commons.io.IOUtils;
import org.onlab.util.Tools;
import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
import org.onosproject.store.cluster.messaging.MessagingException;
import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -40,79 +39,64 @@
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.util.Map;
-import java.util.Objects;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CLOSE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.MESSAGE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
/**
- * {@link Connection} implementation for CopycatTransport.
+ * Base Copycat Transport connection.
*/
public class CopycatTransportConnection implements Connection {
-
+ private final Logger log = LoggerFactory.getLogger(getClass());
+ private final long connectionId;
+ private final String localSubject;
+ private final String remoteSubject;
+ private final PartitionId partitionId;
+ private final Endpoint endpoint;
+ private final MessagingService messagingService;
+ private final ThreadContext context;
+ private final Map<Class, InternalHandler> handlers = new ConcurrentHashMap<>();
private final Listeners<Throwable> exceptionListeners = new Listeners<>();
private final Listeners<Connection> closeListeners = new Listeners<>();
- static final byte SUCCESS = 0x03;
- static final byte FAILURE = 0x04;
-
- private final long connectionId;
- private final CopycatTransport.Mode mode;
- private final Address remoteAddress;
- private final MessagingService messagingService;
- private final String outboundMessageSubject;
- private final String inboundMessageSubject;
- private final ThreadContext context;
- private final Map<Class<?>, InternalHandler> handlers = Maps.newConcurrentMap();
-
- CopycatTransportConnection(long connectionId,
- CopycatTransport.Mode mode,
+ CopycatTransportConnection(
+ long connectionId,
+ Mode mode,
PartitionId partitionId,
- Address address,
+ Endpoint endpoint,
MessagingService messagingService,
ThreadContext context) {
this.connectionId = connectionId;
- this.mode = checkNotNull(mode);
- this.remoteAddress = checkNotNull(address);
- this.messagingService = checkNotNull(messagingService);
- if (mode == CopycatTransport.Mode.CLIENT) {
- this.outboundMessageSubject = String.format("onos-copycat-%s", partitionId);
- this.inboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
- } else {
- this.outboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
- this.inboundMessageSubject = String.format("onos-copycat-%s", partitionId);
- }
- this.context = checkNotNull(context);
- }
-
- public void setBidirectional() {
- messagingService.registerHandler(inboundMessageSubject, (sender, payload) -> {
- try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
- if (input.readLong() != connectionId) {
- throw new IllegalStateException("Invalid connection Id");
- }
- return handle(IOUtils.toByteArray(input));
- } catch (IOException e) {
- Throwables.propagate(e);
- return null;
- }
- });
+ this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+ this.localSubject = mode.getLocalSubject(partitionId, connectionId);
+ this.remoteSubject = mode.getRemoteSubject(partitionId, connectionId);
+ this.endpoint = checkNotNull(endpoint, "endpoint cannot be null");
+ this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
+ this.context = checkNotNull(context, "context cannot be null");
+ messagingService.registerHandler(localSubject, this::handle);
}
@Override
public <T, U> CompletableFuture<U> send(T message) {
ThreadContext context = ThreadContext.currentContextOrThrow();
- CompletableFuture<U> result = new CompletableFuture<>();
+ CompletableFuture<U> future = new CompletableFuture<>();
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
- new DataOutputStream(baos).writeLong(connectionId);
+ DataOutputStream dos = new DataOutputStream(baos);
+ dos.writeByte(MESSAGE);
context.serializer().writeObject(message, baos);
if (message instanceof ReferenceCounted) {
((ReferenceCounted<?>) message).release();
}
- messagingService.sendAndReceive(CopycatTransport.toEndpoint(remoteAddress),
- outboundMessageSubject,
+ messagingService.sendAndReceive(endpoint,
+ remoteSubject,
baos.toByteArray(),
context.executor())
.whenComplete((r, e) -> {
@@ -123,20 +107,23 @@
wrappedError = new TransportException(e);
}
}
- handleResponse(r, wrappedError, result, context);
+ handleResponse(r, wrappedError, future);
});
} catch (SerializationException | IOException e) {
- result.completeExceptionally(e);
+ future.completeExceptionally(e);
}
- return result;
+ return future;
}
- private <T> void handleResponse(byte[] response,
- Throwable error,
- CompletableFuture<T> future,
- ThreadContext context) {
+ /**
+ * Handles a response received from the other side of the connection.
+ */
+ private <T> void handleResponse(
+ byte[] response,
+ Throwable error,
+ CompletableFuture<T> future) {
if (error != null) {
- context.execute(() -> future.completeExceptionally(error));
+ future.completeExceptionally(error);
return;
}
checkNotNull(response);
@@ -145,36 +132,54 @@
byte status = (byte) input.read();
if (status == FAILURE) {
Throwable t = context.serializer().readObject(input);
- context.execute(() -> future.completeExceptionally(t));
+ future.completeExceptionally(t);
} else {
- context.execute(() -> {
- try {
- future.complete(context.serializer().readObject(input));
- } catch (SerializationException e) {
- future.completeExceptionally(e);
- }
- });
+ try {
+ future.complete(context.serializer().readObject(input));
+ } catch (SerializationException e) {
+ future.completeExceptionally(e);
+ }
}
} catch (IOException e) {
- context.execute(() -> future.completeExceptionally(e));
+ future.completeExceptionally(e);
}
}
- @Override
- public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
- Assert.notNull(type, "type");
- handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
- return null;
+ /**
+ * Handles a message sent to the connection.
+ */
+ private CompletableFuture<byte[]> handle(Endpoint sender, byte[] payload) {
+ try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
+ byte type = input.readByte();
+ switch (type) {
+ case MESSAGE:
+ return handleMessage(IOUtils.toByteArray(input));
+ case CLOSE:
+ return handleClose();
+ default:
+ throw new IllegalStateException("Invalid message type");
+ }
+ } catch (IOException e) {
+ Throwables.propagate(e);
+ return null;
+ }
}
- public CompletableFuture<byte[]> handle(byte[] message) {
+ /**
+ * Handles a message from the other side of the connection.
+ */
+ @SuppressWarnings("unchecked")
+ private CompletableFuture<byte[]> handleMessage(byte[] message) {
try {
Object request = context.serializer().readObject(new ByteArrayInputStream(message));
InternalHandler handler = handlers.get(request.getClass());
if (handler == null) {
+ log.warn("No handler registered on connection {}-{} for type {}",
+ partitionId, connectionId, request.getClass());
return Tools.exceptionalFuture(new IllegalStateException(
"No handler registered for " + request.getClass()));
}
+
return handler.handle(request).handle((result, error) -> {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
baos.write(error != null ? FAILURE : SUCCESS);
@@ -190,67 +195,159 @@
}
}
- @Override
- public Listener<Throwable> exceptionListener(Consumer<Throwable> listener) {
- return exceptionListeners.add(listener);
+ /**
+ * Handles a close request from the other side of the connection.
+ */
+ private CompletableFuture<byte[]> handleClose() {
+ CompletableFuture<byte[]> future = new CompletableFuture<>();
+ context.executor().execute(() -> {
+ cleanup();
+ ByteBuffer responseBuffer = ByteBuffer.allocate(1);
+ responseBuffer.put(SUCCESS);
+ future.complete(responseBuffer.array());
+ });
+ return future;
}
@Override
- public Listener<Connection> closeListener(Consumer<Connection> listener) {
- return closeListeners.add(listener);
+ public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
+ if (log.isTraceEnabled()) {
+ log.trace("Registered handler on connection {}-{}: {}", partitionId, connectionId, type);
+ }
+ handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
+ return this;
+ }
+
+ @Override
+ public Listener<Throwable> exceptionListener(Consumer<Throwable> consumer) {
+ return exceptionListeners.add(consumer);
+ }
+
+ @Override
+ public Listener<Connection> closeListener(Consumer<Connection> consumer) {
+ return closeListeners.add(consumer);
}
@Override
public CompletableFuture<Void> close() {
- closeListeners.forEach(listener -> listener.accept(this));
- if (mode == CopycatTransport.Mode.CLIENT) {
- messagingService.unregisterHandler(inboundMessageSubject);
- }
- return CompletableFuture.completedFuture(null);
+ log.debug("Closing connection {}-{}", partitionId, connectionId);
+
+ ByteBuffer requestBuffer = ByteBuffer.allocate(1);
+ requestBuffer.put(CLOSE);
+
+ ThreadContext context = ThreadContext.currentContextOrThrow();
+ CompletableFuture<Void> future = new CompletableFuture<>();
+ messagingService.sendAndReceive(endpoint, remoteSubject, requestBuffer.array(), context.executor())
+ .whenComplete((payload, error) -> {
+ cleanup();
+ Throwable wrappedError = error;
+ if (error != null) {
+ Throwable rootCause = Throwables.getRootCause(error);
+ if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
+ wrappedError = new TransportException(error);
+ }
+ future.completeExceptionally(wrappedError);
+ } else {
+ ByteBuffer responseBuffer = ByteBuffer.wrap(payload);
+ if (responseBuffer.get() == SUCCESS) {
+ future.complete(null);
+ } else {
+ future.completeExceptionally(new TransportException("Failed to close connection"));
+ }
+ }
+ });
+ return future;
}
- @Override
- public int hashCode() {
- return Objects.hash(connectionId);
+ /**
+ * Cleans up the connection: unregisters this connection's local subject handler from the MessagingService and notifies close listeners.
+ */
+ private void cleanup() {
+ log.debug("Connection {}-{} closed", partitionId, connectionId);
+ messagingService.unregisterHandler(localSubject);
+ closeListeners.accept(this);
}
- @Override
- public boolean equals(Object other) {
- if (!(other instanceof CopycatTransportConnection)) {
- return false;
- }
- return connectionId == ((CopycatTransportConnection) other).connectionId;
+ /**
+ * Connection mode used to indicate whether this side of the connection is
+ * a client or server.
+ */
+ enum Mode {
+
+ /**
+ * Represents the client side of a bi-directional connection.
+ */
+ CLIENT {
+ @Override
+ String getLocalSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+ }
+
+ @Override
+ String getRemoteSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+ }
+ },
+
+ /**
+ * Represents the server side of a bi-directional connection.
+ */
+ SERVER {
+ @Override
+ String getLocalSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+ }
+
+ @Override
+ String getRemoteSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+ }
+ };
+
+ /**
+ * Returns the local messaging service subject for the connection in this mode.
+ * Subjects generated by the connection mode are guaranteed to be globally unique.
+ *
+ * @param partitionId the partition ID to which the connection belongs.
+ * @param connectionId the connection ID.
+ * @return the globally unique local subject for the connection.
+ */
+ abstract String getLocalSubject(PartitionId partitionId, long connectionId);
+
+ /**
+ * Returns the remote messaging service subject for the connection in this mode.
+ * Subjects generated by the connection mode are guaranteed to be globally unique.
+ *
+ * @param partitionId the partition ID to which the connection belongs.
+ * @param connectionId the connection ID.
+ * @return the globally unique remote subject for the connection.
+ */
+ abstract String getRemoteSubject(PartitionId partitionId, long connectionId);
}
- @Override
- public String toString() {
- return MoreObjects.toStringHelper(getClass())
- .add("id", connectionId)
- .toString();
- }
-
- @SuppressWarnings("rawtypes")
- private final class InternalHandler {
-
+ /**
+ * Internal container for a handler/context pair.
+ */
+ private static class InternalHandler {
private final MessageHandler handler;
private final ThreadContext context;
- private InternalHandler(MessageHandler handler, ThreadContext context) {
+ InternalHandler(MessageHandler handler, ThreadContext context) {
this.handler = handler;
this.context = context;
}
@SuppressWarnings("unchecked")
- public CompletableFuture<Object> handle(Object message) {
- CompletableFuture<Object> answer = new CompletableFuture<>();
+ CompletableFuture<Object> handle(Object message) {
+ CompletableFuture<Object> future = new CompletableFuture<>();
context.execute(() -> handler.handle(message).whenComplete((r, e) -> {
if (e != null) {
- answer.completeExceptionally((Throwable) e);
+ future.completeExceptionally((Throwable) e);
} else {
- answer.complete(r);
+ future.complete(r);
}
}));
- return answer;
+ return future;
}
}
}