[ONOS-6075] Rewrite Copycat Transport
- Ensure connection IDs are globally unique
- Ensure connections are closed on each side when close() is called
- Add Transport unit tests
Change-Id: Ia848b075d4030ce74293ecc57fea983693cee265
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
index 17fa742..45cc529 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
@@ -15,73 +15,59 @@
*/
package org.onosproject.store.primitives.impl;
-import static com.google.common.base.Preconditions.checkNotNull;
+import com.google.common.base.Throwables;
+import com.google.common.collect.Maps;
+import io.atomix.catalyst.transport.Address;
+import io.atomix.catalyst.transport.Client;
+import io.atomix.catalyst.transport.Server;
+import io.atomix.catalyst.transport.Transport;
+import org.onlab.packet.IpAddress;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.UnknownHostException;
import java.util.Map;
-import org.onlab.packet.IpAddress;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.Endpoint;
-import org.onosproject.store.cluster.messaging.MessagingService;
-
-import com.google.common.base.Throwables;
-import com.google.common.collect.Maps;
-
-import io.atomix.catalyst.transport.Address;
-import io.atomix.catalyst.transport.Client;
-import io.atomix.catalyst.transport.Server;
-import io.atomix.catalyst.transport.Transport;
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
/**
- * Custom {@link Transport transport} for Copycat interactions
- * built on top of {@link MessagingService}.
- *
- * @see CopycatTransportServer
- * @see CopycatTransportClient
+ * Copycat transport implementation built on {@link MessagingService}.
*/
public class CopycatTransport implements Transport {
-
- /**
- * Transport Mode.
- */
- public enum Mode {
- /**
- * Signifies transport for client {@literal ->} server interaction.
- */
- CLIENT,
-
- /**
- * Signified transport for server {@literal ->} server interaction.
- */
- SERVER
- }
-
- private final Mode mode;
private final PartitionId partitionId;
private final MessagingService messagingService;
private static final Map<Address, Endpoint> EP_LOOKUP_CACHE = Maps.newConcurrentMap();
private static final Map<Endpoint, Address> ADDRESS_LOOKUP_CACHE = Maps.newConcurrentMap();
- public CopycatTransport(Mode mode, PartitionId partitionId, MessagingService messagingService) {
- this.mode = checkNotNull(mode);
- this.partitionId = checkNotNull(partitionId);
- this.messagingService = checkNotNull(messagingService);
+ static final byte MESSAGE = 0x01;
+ static final byte CONNECT = 0x02;
+ static final byte CLOSE = 0x03;
+
+ static final byte SUCCESS = 0x01;
+ static final byte FAILURE = 0x02;
+
+ public CopycatTransport(PartitionId partitionId, MessagingService messagingService) {
+ this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+ this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
}
@Override
public Client client() {
- return new CopycatTransportClient(partitionId,
- messagingService,
- mode);
+ return new CopycatTransportClient(partitionId, messagingService);
}
@Override
public Server server() {
- return new CopycatTransportServer(partitionId,
- messagingService);
+ return new CopycatTransportServer(partitionId, messagingService);
+ }
+
+ @Override
+ public String toString() {
+ return toStringHelper(this).toString();
}
/**
@@ -89,7 +75,7 @@
* @param address address
* @return end point
*/
- public static Endpoint toEndpoint(Address address) {
+ static Endpoint toEndpoint(Address address) {
return EP_LOOKUP_CACHE.computeIfAbsent(address, a -> {
try {
return new Endpoint(IpAddress.valueOf(InetAddress.getByName(a.host())), a.port());
@@ -105,7 +91,7 @@
* @param endpoint end point
* @return address
*/
- public static Address toAddress(Endpoint endpoint) {
+ static Address toAddress(Endpoint endpoint) {
return ADDRESS_LOOKUP_CACHE.computeIfAbsent(endpoint, ep -> {
try {
InetAddress host = InetAddress.getByAddress(endpoint.host().toOctets());
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
index 7a1cecf..3567338 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportClient.java
@@ -15,53 +15,94 @@
*/
package org.onosproject.store.primitives.impl;
-import static com.google.common.base.Preconditions.checkNotNull;
-
-import java.util.Set;
-import java.util.concurrent.CompletableFuture;
-
-import org.apache.commons.lang.math.RandomUtils;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.MessagingService;
-
+import com.google.common.base.Throwables;
import com.google.common.collect.Sets;
-
+import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Client;
import io.atomix.catalyst.transport.Connection;
-import io.atomix.catalyst.concurrent.ThreadContext;
+import io.atomix.catalyst.transport.TransportException;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingException;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.net.ConnectException;
+import java.nio.ByteBuffer;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CONNECT;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
/**
- * {@link Client} implementation for {@link CopycatTransport}.
+ * Copycat transport client implementation.
*/
public class CopycatTransportClient implements Client {
-
+ private final Logger log = LoggerFactory.getLogger(getClass());
private final PartitionId partitionId;
+ private final String serverSubject;
private final MessagingService messagingService;
- private final CopycatTransport.Mode mode;
private final Set<CopycatTransportConnection> connections = Sets.newConcurrentHashSet();
- CopycatTransportClient(PartitionId partitionId, MessagingService messagingService, CopycatTransport.Mode mode) {
- this.partitionId = checkNotNull(partitionId);
- this.messagingService = checkNotNull(messagingService);
- this.mode = checkNotNull(mode);
+ public CopycatTransportClient(PartitionId partitionId, MessagingService messagingService) {
+ this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+ this.serverSubject = String.format("onos-copycat-%s", partitionId);
+ this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
}
@Override
- public CompletableFuture<Connection> connect(Address remoteAddress) {
+ public CompletableFuture<Connection> connect(Address address) {
+ CompletableFuture<Connection> future = new CompletableFuture<>();
ThreadContext context = ThreadContext.currentContextOrThrow();
- CopycatTransportConnection connection = new CopycatTransportConnection(
- nextConnectionId(),
- CopycatTransport.Mode.CLIENT,
- partitionId,
- remoteAddress,
- messagingService,
- context);
- if (mode == CopycatTransport.Mode.CLIENT) {
- connection.setBidirectional();
- }
- connections.add(connection);
- return CompletableFuture.supplyAsync(() -> connection, context.executor());
+
+ Endpoint endpoint = CopycatTransport.toEndpoint(address);
+
+ log.debug("Connecting to {}", address);
+
+ ByteBuffer requestBuffer = ByteBuffer.allocate(1);
+ requestBuffer.put(CONNECT);
+
+ // Send a connect request to the server to get a unique connection ID.
+ messagingService.sendAndReceive(endpoint, serverSubject, requestBuffer.array(), context.executor())
+ .whenComplete((payload, error) -> {
+ Throwable wrappedError = error;
+ if (error != null) {
+ Throwable rootCause = Throwables.getRootCause(error);
+ if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
+ wrappedError = new TransportException(error);
+ }
+ log.warn("Connection to {} failed! Reason: {}", address, wrappedError);
+ future.completeExceptionally(wrappedError);
+ } else {
+ // If the connection is successful, the server will send back a
+ // connection ID indicating where to send messages for the connection.
+ ByteBuffer responseBuffer = ByteBuffer.wrap(payload);
+ if (responseBuffer.get() == SUCCESS) {
+ long connectionId = responseBuffer.getLong();
+ CopycatTransportConnection connection = new CopycatTransportConnection(
+ connectionId,
+ CopycatTransportConnection.Mode.CLIENT,
+ partitionId,
+ endpoint,
+ messagingService,
+ context);
+ connection.closeListener(connections::remove);
+ connections.add(connection);
+ future.complete(connection);
+ log.debug("Created connection {}-{} to {}", partitionId, connectionId, address);
+ } else {
+ log.warn("Connection to {} failed!");
+ future.completeExceptionally(new ConnectException());
+ }
+ }
+ });
+ return future;
+
}
@Override
@@ -69,7 +110,11 @@
return CompletableFuture.allOf(connections.stream().map(Connection::close).toArray(CompletableFuture[]::new));
}
- private long nextConnectionId() {
- return RandomUtils.nextLong();
+ @Override
+ public String toString() {
+ return toStringHelper(this)
+ .add("partitionId", partitionId)
+ .toString();
}
- }
+}
+
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
index f2752cd..fadb6dd 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
@@ -15,24 +15,15 @@
*/
package org.onosproject.store.primitives.impl;
-import com.google.common.base.MoreObjects;
-import com.google.common.base.Throwables;
-import com.google.common.collect.Maps;
+
import io.atomix.catalyst.concurrent.Listener;
import io.atomix.catalyst.concurrent.Listeners;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.serializer.SerializationException;
-import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Connection;
import io.atomix.catalyst.transport.MessageHandler;
import io.atomix.catalyst.transport.TransportException;
-import io.atomix.catalyst.util.Assert;
import io.atomix.catalyst.util.reference.ReferenceCounted;
-import org.apache.commons.io.IOUtils;
-import org.onlab.util.Tools;
-import org.onosproject.cluster.PartitionId;
-import org.onosproject.store.cluster.messaging.MessagingException;
-import org.onosproject.store.cluster.messaging.MessagingService;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
@@ -40,79 +31,76 @@
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
+import java.nio.ByteBuffer;
import java.util.Map;
-import java.util.Objects;
import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
+import org.apache.commons.io.IOUtils;
+import org.onlab.util.Tools;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingException;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import com.google.common.base.Throwables;
+
import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CLOSE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.MESSAGE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
+
/**
- * {@link Connection} implementation for CopycatTransport.
+ * Base Copycat Transport connection.
*/
public class CopycatTransportConnection implements Connection {
-
+ private final Logger log = LoggerFactory.getLogger(getClass());
+ private final long connectionId;
+ private final String localSubject;
+ private final String remoteSubject;
+ private final PartitionId partitionId;
+ private final Endpoint endpoint;
+ private final MessagingService messagingService;
+ private final ThreadContext context;
+ private final Map<Class, InternalHandler> handlers = new ConcurrentHashMap<>();
private final Listeners<Throwable> exceptionListeners = new Listeners<>();
private final Listeners<Connection> closeListeners = new Listeners<>();
- static final byte SUCCESS = 0x03;
- static final byte FAILURE = 0x04;
-
- private final long connectionId;
- private final CopycatTransport.Mode mode;
- private final Address remoteAddress;
- private final MessagingService messagingService;
- private final String outboundMessageSubject;
- private final String inboundMessageSubject;
- private final ThreadContext context;
- private final Map<Class<?>, InternalHandler> handlers = Maps.newConcurrentMap();
-
- CopycatTransportConnection(long connectionId,
- CopycatTransport.Mode mode,
+ CopycatTransportConnection(
+ long connectionId,
+ Mode mode,
PartitionId partitionId,
- Address address,
+ Endpoint endpoint,
MessagingService messagingService,
ThreadContext context) {
this.connectionId = connectionId;
- this.mode = checkNotNull(mode);
- this.remoteAddress = checkNotNull(address);
- this.messagingService = checkNotNull(messagingService);
- if (mode == CopycatTransport.Mode.CLIENT) {
- this.outboundMessageSubject = String.format("onos-copycat-%s", partitionId);
- this.inboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
- } else {
- this.outboundMessageSubject = String.format("onos-copycat-%s-%d", partitionId, connectionId);
- this.inboundMessageSubject = String.format("onos-copycat-%s", partitionId);
- }
- this.context = checkNotNull(context);
- }
-
- public void setBidirectional() {
- messagingService.registerHandler(inboundMessageSubject, (sender, payload) -> {
- try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
- if (input.readLong() != connectionId) {
- throw new IllegalStateException("Invalid connection Id");
- }
- return handle(IOUtils.toByteArray(input));
- } catch (IOException e) {
- Throwables.propagate(e);
- return null;
- }
- });
+ this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+ this.localSubject = mode.getLocalSubject(partitionId, connectionId);
+ this.remoteSubject = mode.getRemoteSubject(partitionId, connectionId);
+ this.endpoint = checkNotNull(endpoint, "endpoint cannot be null");
+ this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
+ this.context = checkNotNull(context, "context cannot be null");
+ messagingService.registerHandler(localSubject, this::handle);
}
@Override
public <T, U> CompletableFuture<U> send(T message) {
ThreadContext context = ThreadContext.currentContextOrThrow();
- CompletableFuture<U> result = new CompletableFuture<>();
+ CompletableFuture<U> future = new CompletableFuture<>();
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
- new DataOutputStream(baos).writeLong(connectionId);
+ DataOutputStream dos = new DataOutputStream(baos);
+ dos.writeByte(MESSAGE);
context.serializer().writeObject(message, baos);
if (message instanceof ReferenceCounted) {
((ReferenceCounted<?>) message).release();
}
- messagingService.sendAndReceive(CopycatTransport.toEndpoint(remoteAddress),
- outboundMessageSubject,
+ messagingService.sendAndReceive(endpoint,
+ remoteSubject,
baos.toByteArray(),
context.executor())
.whenComplete((r, e) -> {
@@ -123,20 +111,23 @@
wrappedError = new TransportException(e);
}
}
- handleResponse(r, wrappedError, result, context);
+ handleResponse(r, wrappedError, future);
});
} catch (SerializationException | IOException e) {
- result.completeExceptionally(e);
+ future.completeExceptionally(e);
}
- return result;
+ return future;
}
- private <T> void handleResponse(byte[] response,
- Throwable error,
- CompletableFuture<T> future,
- ThreadContext context) {
+ /**
+ * Handles a response received from the other side of the connection.
+ */
+ private <T> void handleResponse(
+ byte[] response,
+ Throwable error,
+ CompletableFuture<T> future) {
if (error != null) {
- context.execute(() -> future.completeExceptionally(error));
+ future.completeExceptionally(error);
return;
}
checkNotNull(response);
@@ -145,36 +136,54 @@
byte status = (byte) input.read();
if (status == FAILURE) {
Throwable t = context.serializer().readObject(input);
- context.execute(() -> future.completeExceptionally(t));
+ future.completeExceptionally(t);
} else {
- context.execute(() -> {
- try {
- future.complete(context.serializer().readObject(input));
- } catch (SerializationException e) {
- future.completeExceptionally(e);
- }
- });
+ try {
+ future.complete(context.serializer().readObject(input));
+ } catch (SerializationException e) {
+ future.completeExceptionally(e);
+ }
}
} catch (IOException e) {
- context.execute(() -> future.completeExceptionally(e));
+ future.completeExceptionally(e);
}
}
- @Override
- public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
- Assert.notNull(type, "type");
- handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
- return null;
+ /**
+ * Handles a message sent to the connection.
+ */
+ private CompletableFuture<byte[]> handle(Endpoint sender, byte[] payload) {
+ try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
+ byte type = input.readByte();
+ switch (type) {
+ case MESSAGE:
+ return handleMessage(IOUtils.toByteArray(input));
+ case CLOSE:
+ return handleClose();
+ default:
+ throw new IllegalStateException("Invalid message type");
+ }
+ } catch (IOException e) {
+ Throwables.propagate(e);
+ return null;
+ }
}
- public CompletableFuture<byte[]> handle(byte[] message) {
+ /**
+ * Handles a message from the other side of the connection.
+ */
+ @SuppressWarnings("unchecked")
+ private CompletableFuture<byte[]> handleMessage(byte[] message) {
try {
Object request = context.serializer().readObject(new ByteArrayInputStream(message));
InternalHandler handler = handlers.get(request.getClass());
if (handler == null) {
+ log.warn("No handler registered on connection {}-{} for type {}",
+ partitionId, connectionId, request.getClass());
return Tools.exceptionalFuture(new IllegalStateException(
"No handler registered for " + request.getClass()));
}
+
return handler.handle(request).handle((result, error) -> {
try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) {
baos.write(error != null ? FAILURE : SUCCESS);
@@ -190,67 +199,159 @@
}
}
- @Override
- public Listener<Throwable> exceptionListener(Consumer<Throwable> listener) {
- return exceptionListeners.add(listener);
+ /**
+ * Handles a close request from the other side of the connection.
+ */
+ private CompletableFuture<byte[]> handleClose() {
+ CompletableFuture<byte[]> future = new CompletableFuture<>();
+ context.executor().execute(() -> {
+ cleanup();
+ ByteBuffer responseBuffer = ByteBuffer.allocate(1);
+ responseBuffer.put(SUCCESS);
+ future.complete(responseBuffer.array());
+ });
+ return future;
}
@Override
- public Listener<Connection> closeListener(Consumer<Connection> listener) {
- return closeListeners.add(listener);
+ public <T, U> Connection handler(Class<T> type, MessageHandler<T, U> handler) {
+ if (log.isTraceEnabled()) {
+ log.trace("Registered handler on connection {}-{}: {}", partitionId, connectionId, type);
+ }
+ handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow()));
+ return this;
+ }
+
+ @Override
+ public Listener<Throwable> exceptionListener(Consumer<Throwable> consumer) {
+ return exceptionListeners.add(consumer);
+ }
+
+ @Override
+ public Listener<Connection> closeListener(Consumer<Connection> consumer) {
+ return closeListeners.add(consumer);
}
@Override
public CompletableFuture<Void> close() {
- closeListeners.forEach(listener -> listener.accept(this));
- if (mode == CopycatTransport.Mode.CLIENT) {
- messagingService.unregisterHandler(inboundMessageSubject);
- }
- return CompletableFuture.completedFuture(null);
+ log.debug("Closing connection {}-{}", partitionId, connectionId);
+
+ ByteBuffer requestBuffer = ByteBuffer.allocate(1);
+ requestBuffer.put(CLOSE);
+
+ ThreadContext context = ThreadContext.currentContextOrThrow();
+ CompletableFuture<Void> future = new CompletableFuture<>();
+ messagingService.sendAndReceive(endpoint, remoteSubject, requestBuffer.array(), context.executor())
+ .whenComplete((payload, error) -> {
+ cleanup();
+ Throwable wrappedError = error;
+ if (error != null) {
+ Throwable rootCause = Throwables.getRootCause(error);
+ if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
+ wrappedError = new TransportException(error);
+ }
+ future.completeExceptionally(wrappedError);
+ } else {
+ ByteBuffer responseBuffer = ByteBuffer.wrap(payload);
+ if (responseBuffer.get() == SUCCESS) {
+ future.complete(null);
+ } else {
+ future.completeExceptionally(new TransportException("Failed to close connection"));
+ }
+ }
+ });
+ return future;
}
- @Override
- public int hashCode() {
- return Objects.hash(connectionId);
+ /**
+ * Cleans up the connection, unregistering handlers registered on the MessagingService.
+ */
+ private void cleanup() {
+ log.debug("Connection {}-{} closed", partitionId, connectionId);
+ messagingService.unregisterHandler(localSubject);
+ closeListeners.accept(this);
}
- @Override
- public boolean equals(Object other) {
- if (!(other instanceof CopycatTransportConnection)) {
- return false;
- }
- return connectionId == ((CopycatTransportConnection) other).connectionId;
+ /**
+ * Connection mode used to indicate whether this side of the connection is
+ * a client or server.
+ */
+ enum Mode {
+
+ /**
+ * Represents the client side of a bi-directional connection.
+ */
+ CLIENT {
+ @Override
+ String getLocalSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+ }
+
+ @Override
+ String getRemoteSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+ }
+ },
+
+ /**
+ * Represents the server side of a bi-directional connection.
+ */
+ SERVER {
+ @Override
+ String getLocalSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-server", partitionId, connectionId);
+ }
+
+ @Override
+ String getRemoteSubject(PartitionId partitionId, long connectionId) {
+ return String.format("onos-copycat-%s-%d-client", partitionId, connectionId);
+ }
+ };
+
+ /**
+ * Returns the local messaging service subject for the connection in this mode.
+ * Subjects generated by the connection mode are guaranteed to be globally unique.
+ *
+ * @param partitionId the partition ID to which the connection belongs.
+ * @param connectionId the connection ID.
+ * @return the globally unique local subject for the connection.
+ */
+ abstract String getLocalSubject(PartitionId partitionId, long connectionId);
+
+ /**
+ * Returns the remote messaging service subject for the connection in this mode.
+ * Subjects generated by the connection mode are guaranteed to be globally unique.
+ *
+ * @param partitionId the partition ID to which the connection belongs.
+ * @param connectionId the connection ID.
+ * @return the globally unique remote subject for the connection.
+ */
+ abstract String getRemoteSubject(PartitionId partitionId, long connectionId);
}
- @Override
- public String toString() {
- return MoreObjects.toStringHelper(getClass())
- .add("id", connectionId)
- .toString();
- }
-
- @SuppressWarnings("rawtypes")
- private final class InternalHandler {
-
+ /**
+ * Internal container for a handler/context pair.
+ */
+ private static class InternalHandler {
private final MessageHandler handler;
private final ThreadContext context;
- private InternalHandler(MessageHandler handler, ThreadContext context) {
+ InternalHandler(MessageHandler handler, ThreadContext context) {
this.handler = handler;
this.context = context;
}
@SuppressWarnings("unchecked")
- public CompletableFuture<Object> handle(Object message) {
- CompletableFuture<Object> answer = new CompletableFuture<>();
+ CompletableFuture<Object> handle(Object message) {
+ CompletableFuture<Object> future = new CompletableFuture<>();
context.execute(() -> handler.handle(message).whenComplete((r, e) -> {
if (e != null) {
- answer.completeExceptionally((Throwable) e);
+ future.completeExceptionally((Throwable) e);
} else {
- answer.complete(r);
+ future.complete(r);
}
}));
- return answer;
+ return future;
}
}
}
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
index ff64112..aee7647 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportServer.java
@@ -15,113 +15,99 @@
*/
package org.onosproject.store.primitives.impl;
-import static com.google.common.base.Preconditions.checkNotNull;
-import static org.slf4j.LoggerFactory.getLogger;
-import io.atomix.catalyst.concurrent.CatalystThreadFactory;
-import io.atomix.catalyst.concurrent.SingleThreadContext;
+import com.google.common.collect.Sets;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Connection;
import io.atomix.catalyst.transport.Server;
-
-import java.io.ByteArrayInputStream;
-import java.io.DataInputStream;
-import java.io.IOException;
-import java.util.Map;
-import java.util.concurrent.CompletableFuture;
-import java.util.concurrent.Executors;
-import java.util.concurrent.ScheduledExecutorService;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.function.Consumer;
-
-import org.apache.commons.io.IOUtils;
-import org.onlab.util.Tools;
+import org.apache.commons.lang3.RandomUtils;
import org.onosproject.cluster.PartitionId;
import org.onosproject.store.cluster.messaging.MessagingService;
import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
-import com.google.common.collect.Maps;
+import java.nio.ByteBuffer;
+import java.util.Set;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.Consumer;
+
+import static com.google.common.base.MoreObjects.toStringHelper;
+import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.store.primitives.impl.CopycatTransport.CONNECT;
+import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
+import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
/**
- * {@link Server} implementation for {@link CopycatTransport}.
+ * Copycat transport server implementation.
*/
public class CopycatTransportServer implements Server {
-
- private final Logger log = getLogger(getClass());
- private final AtomicBoolean listening = new AtomicBoolean(false);
- private CompletableFuture<Void> listenFuture = new CompletableFuture<>();
- private final ScheduledExecutorService executorService;
+ private final Logger log = LoggerFactory.getLogger(getClass());
private final PartitionId partitionId;
+ private final String serverSubject;
private final MessagingService messagingService;
- private final String messageSubject;
- private final Map<Long, CopycatTransportConnection> connections = Maps.newConcurrentMap();
+ private final Set<CopycatTransportConnection> connections = Sets.newConcurrentHashSet();
- CopycatTransportServer(PartitionId partitionId, MessagingService messagingService) {
- this.partitionId = checkNotNull(partitionId);
- this.messagingService = checkNotNull(messagingService);
- this.messageSubject = String.format("onos-copycat-%s", partitionId);
- this.executorService = Executors.newScheduledThreadPool(Math.min(4, Runtime.getRuntime().availableProcessors()),
- new CatalystThreadFactory("copycat-server-p" + partitionId + "-%d"));
+ public CopycatTransportServer(PartitionId partitionId, MessagingService messagingService) {
+ this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
+ this.serverSubject = String.format("onos-copycat-%s", partitionId);
+ this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
}
@Override
- public CompletableFuture<Void> listen(Address address, Consumer<Connection> listener) {
- if (listening.compareAndSet(false, true)) {
- ThreadContext context = ThreadContext.currentContextOrThrow();
- listen(address, listener, context);
- }
- return listenFuture;
- }
+ public CompletableFuture<Void> listen(Address address, Consumer<Connection> consumer) {
+ ThreadContext context = ThreadContext.currentContextOrThrow();
+ messagingService.registerHandler(serverSubject, (sender, payload) -> {
- private void listen(Address address, Consumer<Connection> listener, ThreadContext context) {
- messagingService.registerHandler(messageSubject, (sender, payload) -> {
- try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
- long connectionId = input.readLong();
- AtomicBoolean newConnectionCreated = new AtomicBoolean(false);
- CopycatTransportConnection connection = connections.computeIfAbsent(connectionId, k -> {
- newConnectionCreated.set(true);
- CopycatTransportConnection newConnection = new CopycatTransportConnection(connectionId,
- CopycatTransport.Mode.SERVER,
- partitionId,
- CopycatTransport.toAddress(sender),
- messagingService,
- getOrCreateContext(context));
- log.debug("Created new incoming connection {}", connectionId);
- newConnection.closeListener(c -> connections.remove(connectionId, c));
- return newConnection;
- });
- byte[] request = IOUtils.toByteArray(input);
- return CompletableFuture.supplyAsync(
- () -> {
- if (newConnectionCreated.get()) {
- listener.accept(connection);
- }
- return connection;
- }, context.executor()).thenCompose(c -> c.handle(request));
- } catch (IOException e) {
- return Tools.exceptionalFuture(e);
+ // Only connect messages can be sent to the server. Once a connect message
+ // is received, the connection will register a separate handler for messaging.
+ ByteBuffer requestBuffer = ByteBuffer.wrap(payload);
+ if (requestBuffer.get() != CONNECT) {
+ ByteBuffer responseBuffer = ByteBuffer.allocate(1);
+ responseBuffer.put(FAILURE);
+ return CompletableFuture.completedFuture(responseBuffer.array());
}
+
+ // Create the connection and ensure state is cleaned up when the connection is closed.
+ long connectionId = RandomUtils.nextLong(0L, Long.MAX_VALUE);
+ CopycatTransportConnection connection = new CopycatTransportConnection(
+ connectionId,
+ CopycatTransportConnection.Mode.SERVER,
+ partitionId,
+ sender,
+ messagingService,
+ context);
+ connection.closeListener(connections::remove);
+ connections.add(connection);
+
+ CompletableFuture<byte[]> future = new CompletableFuture<>();
+
+ // We need to ensure the connection event is called on the Copycat thread
+ // and that the future is not completed until the Copycat server has been
+ // able to register message handlers, otherwise some messages can be received
+ // prior to any handlers being registered.
+ context.executor().execute(() -> {
+ log.debug("Created connection {}-{}", partitionId, connectionId);
+ consumer.accept(connection);
+
+ ByteBuffer responseBuffer = ByteBuffer.allocate(9);
+ responseBuffer.put(SUCCESS);
+ responseBuffer.putLong(connectionId);
+ future.complete(responseBuffer.array());
+ });
+ return future;
});
- context.execute(() -> {
- listenFuture.complete(null);
- });
+ return CompletableFuture.completedFuture(null);
}
@Override
public CompletableFuture<Void> close() {
- messagingService.unregisterHandler(messageSubject);
- executorService.shutdown();
- return CompletableFuture.completedFuture(null);
+ return CompletableFuture.allOf(connections.stream().map(Connection::close).toArray(CompletableFuture[]::new));
}
- /**
- * Returns the current execution context or creates one.
- */
- private ThreadContext getOrCreateContext(ThreadContext parentContext) {
- ThreadContext context = ThreadContext.currentContext();
- if (context != null) {
- return context;
- }
- return new SingleThreadContext(executorService, parentContext.serializer().clone());
+ @Override
+ public String toString() {
+ return toStringHelper(this)
+ .add("partitionId", partitionId)
+ .toString();
}
}
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
index 7ed275a..a68b793 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/StoragePartition.java
@@ -131,9 +131,7 @@
StoragePartitionServer server = new StoragePartitionServer(toAddress(localNodeId),
this,
serializer,
- () -> new CopycatTransport(CopycatTransport.Mode.SERVER,
- partition.getId(),
- messagingService),
+ () -> new CopycatTransport(partition.getId(), messagingService),
logFolder);
return server.open().thenRun(() -> this.server = server);
}
@@ -150,9 +148,7 @@
StoragePartitionServer server = new StoragePartitionServer(toAddress(localNodeId),
this,
serializer,
- () -> new CopycatTransport(CopycatTransport.Mode.SERVER,
- partition.getId(),
- messagingService),
+ () -> new CopycatTransport(partition.getId(), messagingService),
logFolder);
return server.join(Collections2.transform(otherMembers, this::toAddress)).thenRun(() -> this.server = server);
}
@@ -160,9 +156,7 @@
private CompletableFuture<StoragePartitionClient> openClient() {
client = new StoragePartitionClient(this,
serializer,
- new CopycatTransport(CopycatTransport.Mode.CLIENT,
- partition.getId(),
- messagingService));
+ new CopycatTransport(partition.getId(), messagingService));
return client.open().thenApply(v -> client);
}
diff --git a/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
new file mode 100644
index 0000000..21b4d70
--- /dev/null
+++ b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
@@ -0,0 +1,360 @@
+/*
+ * Copyright 2017-present Open Networking Laboratory
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.onosproject.store.primitives.impl;
+
+import com.google.common.collect.Lists;
+import io.atomix.catalyst.concurrent.SingleThreadContext;
+import io.atomix.catalyst.concurrent.ThreadContext;
+import io.atomix.catalyst.transport.Address;
+import io.atomix.catalyst.transport.Client;
+import io.atomix.catalyst.transport.Server;
+import io.atomix.catalyst.transport.Transport;
+import io.atomix.copycat.protocol.ConnectRequest;
+import io.atomix.copycat.protocol.ConnectResponse;
+import io.atomix.copycat.protocol.PublishRequest;
+import io.atomix.copycat.protocol.PublishResponse;
+import io.atomix.copycat.protocol.Response;
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import org.onlab.packet.IpAddress;
+import org.onlab.util.Tools;
+import org.onosproject.cluster.PartitionId;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
+
+import java.time.Duration;
+import java.util.Map;
+import java.util.UUID;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Executor;
+import java.util.concurrent.TimeUnit;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.fail;
+import static org.onlab.junit.TestTools.findAvailablePort;
+
+/**
+ * Copycat transport test.
+ */
+public class CopycatTransportTest {
+
+ private static final String IP_STRING = "127.0.0.1";
+
+ private Endpoint endpoint1 = new Endpoint(IpAddress.valueOf(IP_STRING), 5001);
+ private Endpoint endpoint2 = new Endpoint(IpAddress.valueOf(IP_STRING), 5002);
+
+ private MessagingService service1;
+ private MessagingService service2;
+
+ private Transport clientTransport;
+ private ThreadContext clientContext;
+
+ private Transport serverTransport;
+ private ThreadContext serverContext;
+
+ @Before
+ public void setUp() throws Exception {
+ Map<Endpoint, TestMessagingService> services = new ConcurrentHashMap<>();
+
+ endpoint1 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5001));
+ service1 = new TestMessagingService(endpoint1, services);
+ clientTransport = new CopycatTransport(PartitionId.from(1), service1);
+ clientContext = new SingleThreadContext("client-test-%d", CatalystSerializers.getSerializer());
+
+ endpoint2 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5003));
+ service2 = new TestMessagingService(endpoint2, services);
+ serverTransport = new CopycatTransport(PartitionId.from(1), service2);
+ serverContext = new SingleThreadContext("server-test-%d", CatalystSerializers.getSerializer());
+ }
+
+ @After
+ public void tearDown() throws Exception {
+ if (clientContext != null) {
+ clientContext.close();
+ }
+ if (serverContext != null) {
+ serverContext.close();
+ }
+ }
+
+ /**
+ * Tests sending a message from the client side of a Copycat connection to the server side.
+ */
+ @Test
+ public void testCopycatClientConnectionSend() throws Exception {
+ Client client = clientTransport.client();
+ Server server = serverTransport.server();
+
+ CountDownLatch latch = new CountDownLatch(4);
+ CountDownLatch listenLatch = new CountDownLatch(1);
+ CountDownLatch handlerLatch = new CountDownLatch(1);
+ serverContext.executor().execute(() -> {
+ server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+ serverContext.checkThread();
+ latch.countDown();
+ connection.handler(ConnectRequest.class, request -> {
+ serverContext.checkThread();
+ latch.countDown();
+ return CompletableFuture.completedFuture(ConnectResponse.builder()
+ .withStatus(Response.Status.OK)
+ .withLeader(new Address(IP_STRING, endpoint2.port()))
+ .withMembers(Lists.newArrayList(new Address(IP_STRING, endpoint2.port())))
+ .build());
+ });
+ handlerLatch.countDown();
+ }).thenRun(listenLatch::countDown);
+ });
+
+ listenLatch.await(5, TimeUnit.SECONDS);
+
+ clientContext.executor().execute(() -> {
+ client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+ clientContext.checkThread();
+ latch.countDown();
+ try {
+ handlerLatch.await(5, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ fail();
+ }
+ connection.<ConnectRequest, ConnectResponse>send(ConnectRequest.builder()
+ .withClientId(UUID.randomUUID().toString())
+ .build())
+ .thenAccept(response -> {
+ clientContext.checkThread();
+ assertNotNull(response);
+ assertEquals(Response.Status.OK, response.status());
+ latch.countDown();
+ });
+ });
+ });
+
+ latch.await(5, TimeUnit.SECONDS);
+ assertEquals(0, latch.getCount());
+ }
+
+ /**
+ * Tests sending a message from the server side of a Copycat connection to the client side.
+ */
+ @Test
+ public void testCopycatServerConnectionSend() throws Exception {
+ Client client = clientTransport.client();
+ Server server = serverTransport.server();
+
+ CountDownLatch latch = new CountDownLatch(4);
+ CountDownLatch listenLatch = new CountDownLatch(1);
+ serverContext.executor().execute(() -> {
+ server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+ serverContext.checkThread();
+ latch.countDown();
+ serverContext.schedule(Duration.ofMillis(100), () -> {
+ connection.<PublishRequest, PublishResponse>send(PublishRequest.builder()
+ .withSession(1)
+ .withEventIndex(3)
+ .withPreviousIndex(2)
+ .build())
+ .thenAccept(response -> {
+ serverContext.checkThread();
+ assertEquals(Response.Status.OK, response.status());
+ assertEquals(1, response.index());
+ latch.countDown();
+ });
+ });
+ }).thenRun(listenLatch::countDown);
+ });
+
+ listenLatch.await(5, TimeUnit.SECONDS);
+
+ clientContext.executor().execute(() -> {
+ client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+ clientContext.checkThread();
+ latch.countDown();
+ connection.handler(PublishRequest.class, request -> {
+ clientContext.checkThread();
+ latch.countDown();
+ assertEquals(1, request.session());
+ assertEquals(3, request.eventIndex());
+ assertEquals(2, request.previousIndex());
+ return CompletableFuture.completedFuture(PublishResponse.builder()
+ .withStatus(Response.Status.OK)
+ .withIndex(1)
+ .build());
+ });
+ });
+ });
+
+ latch.await(5, TimeUnit.SECONDS);
+ assertEquals(0, latch.getCount());
+ }
+
+ /**
+ * Tests closing the server side of a Copycat connection.
+ */
+ @Test
+ public void testCopycatClientConnectionClose() throws Exception {
+ Client client = clientTransport.client();
+ Server server = serverTransport.server();
+
+ CountDownLatch latch = new CountDownLatch(5);
+ CountDownLatch listenLatch = new CountDownLatch(1);
+ serverContext.executor().execute(() -> {
+ server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+ serverContext.checkThread();
+ latch.countDown();
+ connection.closeListener(c -> {
+ serverContext.checkThread();
+ latch.countDown();
+ });
+ }).thenRun(listenLatch::countDown);
+ });
+
+ listenLatch.await(5, TimeUnit.SECONDS);
+
+ clientContext.executor().execute(() -> {
+ client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+ clientContext.checkThread();
+ latch.countDown();
+ connection.closeListener(c -> {
+ clientContext.checkThread();
+ latch.countDown();
+ });
+ clientContext.schedule(Duration.ofMillis(100), () -> {
+ connection.close().whenComplete((result, error) -> {
+ clientContext.checkThread();
+ latch.countDown();
+ });
+ });
+ });
+ });
+
+ latch.await(5, TimeUnit.SECONDS);
+ assertEquals(0, latch.getCount());
+ }
+
+ /**
+ * Tests closing the server side of a Copycat connection.
+ */
+ @Test
+ public void testCopycatServerConnectionClose() throws Exception {
+ Client client = clientTransport.client();
+ Server server = serverTransport.server();
+
+ CountDownLatch latch = new CountDownLatch(5);
+ CountDownLatch listenLatch = new CountDownLatch(1);
+ serverContext.executor().execute(() -> {
+ server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+ serverContext.checkThread();
+ latch.countDown();
+ connection.closeListener(c -> {
+ latch.countDown();
+ });
+ serverContext.schedule(Duration.ofMillis(100), () -> {
+ connection.close().whenComplete((result, error) -> {
+ serverContext.checkThread();
+ latch.countDown();
+ });
+ });
+ }).thenRun(listenLatch::countDown);
+ });
+
+ listenLatch.await(5, TimeUnit.SECONDS);
+
+ clientContext.executor().execute(() -> {
+ client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+ clientContext.checkThread();
+ latch.countDown();
+ connection.closeListener(c -> {
+ latch.countDown();
+ });
+ });
+ });
+
+ latch.await(5, TimeUnit.SECONDS);
+ assertEquals(0, latch.getCount());
+ }
+
+ /**
+ * Custom implementation of {@code MessagingService} used for testing. Really, this should
+ * be mocked but suffices for now.
+ */
+ public static final class TestMessagingService implements MessagingService {
+ private final Endpoint endpoint;
+ private final Map<Endpoint, TestMessagingService> services;
+ private final Map<String, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>>> handlers =
+ new ConcurrentHashMap<>();
+
+ TestMessagingService(Endpoint endpoint, Map<Endpoint, TestMessagingService> services) {
+ this.endpoint = endpoint;
+ this.services = services;
+ services.put(endpoint, this);
+ }
+
+ private CompletableFuture<byte[]> handle(Endpoint ep, String type, byte[] message, Executor executor) {
+ BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler = handlers.get(type);
+ if (handler == null) {
+ return Tools.exceptionalFuture(new IllegalStateException());
+ }
+ return handler.apply(ep, message).thenApplyAsync(r -> r, executor);
+ }
+
+ @Override
+ public CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload) {
+ // Unused for testing
+ return null;
+ }
+
+ @Override
+ public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
+ // Unused for testing
+ return null;
+ }
+
+ @Override
+ public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
+ TestMessagingService service = services.get(ep);
+ if (service == null) {
+ return Tools.exceptionalFuture(new IllegalStateException());
+ }
+ return service.handle(endpoint, type, payload, executor);
+ }
+
+ @Override
+ public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
+ // Unused for testing
+ }
+
+ @Override
+ public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
+ // Unused for testing
+ }
+
+ @Override
+ public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
+ handlers.put(type, handler);
+ }
+
+ @Override
+ public void unregisterHandler(String type) {
+ handlers.remove(type);
+ }
+ }
+
+}
\ No newline at end of file