blob: a3a8539f00b87599dac7035aef516913024317b6 [file] [log] [blame]
/*
* Copyright 2016-present Open Networking Laboratory
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.onosproject.store.primitives.impl;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.SocketException;
import java.nio.ByteBuffer;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Consumer;
import java.util.function.Function;
import com.google.common.base.Throwables;
import io.atomix.catalyst.concurrent.Listener;
import io.atomix.catalyst.concurrent.Listeners;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.serializer.SerializationException;
import io.atomix.catalyst.transport.Connection;
import io.atomix.catalyst.transport.TransportException;
import io.atomix.catalyst.util.reference.ReferenceCounted;
import org.apache.commons.io.IOUtils;
import org.onlab.util.Tools;
import org.onosproject.cluster.PartitionId;
import org.onosproject.store.cluster.messaging.Endpoint;
import org.onosproject.store.cluster.messaging.MessagingException;
import org.onosproject.store.cluster.messaging.MessagingService;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.google.common.base.Preconditions.checkNotNull;
import static org.onosproject.store.primitives.impl.CopycatTransport.CLOSE;
import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE;
import static org.onosproject.store.primitives.impl.CopycatTransport.MESSAGE;
import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS;
/**
 * Copycat transport {@link Connection} that exchanges messages with its peer
 * through the ONOS {@link MessagingService}.
 */
public class CopycatTransportConnection implements Connection {
/** Upper bound, in bytes, on a serialized outgoing message or response (1024 * 1024). */
private static final int MAX_MESSAGE_SIZE = 1024 * 1024;
private final Logger log = LoggerFactory.getLogger(getClass());
/** Connection identifier; used together with the partition ID to derive the messaging subjects. */
private final long connectionId;
/** Messaging subject on which this side registers its incoming-message handler. */
private final String localSubject;
/** Messaging subject to which outgoing messages and close requests are addressed. */
private final String remoteSubject;
private final PartitionId partitionId;
/** Endpoint to which outgoing messages are sent. */
private final Endpoint endpoint;
private final MessagingService messagingService;
/**
 * Thread context supplied at construction. Its serializer decodes incoming requests and
 * responses and encodes replies; its executor runs remotely requested closes.
 */
private final ThreadContext context;
/** Registered handlers, looked up by the runtime class of the deserialized request. */
private final Map<Class, InternalHandler> handlers = new ConcurrentHashMap<>();
/** Listeners notified with the error when the connection is cleaned up because of a failure. */
private final Listeners<Throwable> exceptionListeners = new Listeners<>();
/** Listeners notified whenever the connection is cleaned up. */
private final Listeners<Connection> closeListeners = new Listeners<>();
/**
 * Creates a connection and registers its incoming-message handler on the local subject.
 *
 * @param connectionId identifier of the connection
 * @param mode side of the connection (client or server); selects the local and remote subjects
 * @param partitionId partition to which the connection belongs
 * @param endpoint endpoint to which outgoing messages are sent
 * @param messagingService messaging service used to send and receive messages
 * @param context thread context whose serializer and executor are used for incoming traffic
 * @throws NullPointerException if any reference argument is null
 */
CopycatTransportConnection(
        long connectionId,
        Mode mode,
        PartitionId partitionId,
        Endpoint endpoint,
        MessagingService messagingService,
        ThreadContext context) {
    this.connectionId = connectionId;
    this.partitionId = checkNotNull(partitionId, "partitionId cannot be null");
    // Validate mode like every other reference argument so a null fails with a clear message
    // instead of a bare NullPointerException on the dereference below.
    checkNotNull(mode, "mode cannot be null");
    this.localSubject = mode.getLocalSubject(partitionId, connectionId);
    this.remoteSubject = mode.getRemoteSubject(partitionId, connectionId);
    this.endpoint = checkNotNull(endpoint, "endpoint cannot be null");
    this.messagingService = checkNotNull(messagingService, "messagingService cannot be null");
    this.context = checkNotNull(context, "context cannot be null");
    messagingService.registerHandler(localSubject, this::handle);
}
/**
 * Sends a message to the remote subject without waiting for a reply payload.
 *
 * <p>The message is written as the {@code MESSAGE} type byte followed by the object as
 * serialized by the calling thread context's serializer. A {@link ReferenceCounted}
 * message is released once it has been serialized. The returned future is completed on
 * the calling thread context's executor. Must be called from a thread that has a current
 * {@link ThreadContext}.
 *
 * @param message the message to send
 * @return a future completed when the send finishes, or completed exceptionally on a
 *         serialization, I/O or messaging failure
 * @throws IllegalArgumentException synchronously, if the serialized message is larger
 *         than {@code MAX_MESSAGE_SIZE}
 */
@Override
public CompletableFuture<Void> send(Object message) {
    ThreadContext callerContext = ThreadContext.currentContextOrThrow();
    CompletableFuture<Void> result = new CompletableFuture<>();
    try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
        new DataOutputStream(buffer).writeByte(MESSAGE);
        callerContext.serializer().writeObject(message, buffer);
        if (message instanceof ReferenceCounted) {
            ((ReferenceCounted<?>) message).release();
        }

        byte[] payload = buffer.toByteArray();
        if (payload.length > MAX_MESSAGE_SIZE) {
            throw new IllegalArgumentException(message + " exceeds maximum message size " + MAX_MESSAGE_SIZE);
        }

        messagingService.sendAsync(endpoint, remoteSubject, payload)
                .whenComplete((ignored, failure) -> callerContext.executor().execute(() -> {
                    // Hand the outcome back on the caller's context.
                    if (failure == null) {
                        result.complete(null);
                    } else {
                        result.completeExceptionally(failure);
                    }
                }));
    } catch (SerializationException | IOException e) {
        result.completeExceptionally(e);
    }
    return result;
}
/**
 * Sends a request to the remote subject and returns a future for the decoded reply.
 *
 * <p>The request is written as the {@code MESSAGE} type byte followed by the object as
 * serialized by the calling thread context's serializer. A {@link ReferenceCounted}
 * message is released once it has been serialized. The reply is processed by
 * {@code handleResponse} using the calling thread context's executor. Must be called from
 * a thread that has a current {@link ThreadContext}.
 *
 * @param message the request to send
 * @param <T> the request type
 * @param <U> the response type
 * @return a future completed with the response, or completed exceptionally on failure
 * @throws IllegalArgumentException synchronously, if the serialized request is larger
 *         than {@code MAX_MESSAGE_SIZE}
 */
@Override
public <T, U> CompletableFuture<U> sendAndReceive(T message) {
    ThreadContext callerContext = ThreadContext.currentContextOrThrow();
    CompletableFuture<U> result = new CompletableFuture<>();
    try (ByteArrayOutputStream buffer = new ByteArrayOutputStream()) {
        new DataOutputStream(buffer).writeByte(MESSAGE);
        callerContext.serializer().writeObject(message, buffer);
        if (message instanceof ReferenceCounted) {
            ((ReferenceCounted<?>) message).release();
        }

        byte[] payload = buffer.toByteArray();
        if (payload.length > MAX_MESSAGE_SIZE) {
            throw new IllegalArgumentException(message + " exceeds maximum message size " + MAX_MESSAGE_SIZE);
        }

        messagingService
                .sendAndReceive(endpoint, remoteSubject, payload, callerContext.executor())
                .whenComplete((reply, failure) -> handleResponse(reply, failure, result));
    } catch (SerializationException | IOException e) {
        result.completeExceptionally(e);
    }
    return result;
}
/**
 * Handles a response received from the other side of the connection and completes
 * {@code future} accordingly.
 *
 * <p>On error: a root cause of {@link MessagingException.NoRemoteHandler} fails the future
 * with a {@link TransportException} and cleans up this connection; a root cause of
 * {@link SocketException} fails the future with a {@link TransportException}; any other
 * error is passed through unchanged.
 *
 * <p>On success the first byte of the response is the status. {@code FAILURE} means the
 * remainder is a serialized {@link Throwable} that fails the future; otherwise the
 * remainder is the serialized result.
 *
 * <p>This method runs inside a completion callback, where a thrown exception would be
 * lost and the future would never complete. Every decoding problem (null response,
 * undeserializable payload, payload of an unexpected type) is therefore routed into the
 * future instead of being thrown.
 *
 * @param response raw response bytes; expected to be non-null when {@code error} is null
 * @param error the messaging error, or null on success
 * @param future the future to complete
 * @param <T> the response type
 */
private <T> void handleResponse(
        byte[] response,
        Throwable error,
        CompletableFuture<T> future) {
    if (error != null) {
        Throwable rootCause = Throwables.getRootCause(error);
        if (rootCause instanceof MessagingException.NoRemoteHandler) {
            future.completeExceptionally(new TransportException(error));
            close(rootCause);
        } else if (rootCause instanceof SocketException) {
            future.completeExceptionally(new TransportException(error));
        } else {
            future.completeExceptionally(error);
        }
        return;
    }

    try {
        checkNotNull(response, "response cannot be null");
        InputStream input = new ByteArrayInputStream(response);
        byte status = (byte) input.read();
        if (status == FAILURE) {
            // The implicit cast to Throwable may itself fail; the catch below covers it.
            Throwable failure = context.serializer().readObject(input);
            future.completeExceptionally(failure);
        } else {
            future.complete(context.serializer().readObject(input));
        }
    } catch (IOException | RuntimeException e) {
        future.completeExceptionally(e);
    }
}
/**
 * Handles a raw payload delivered to this connection's local subject.
 *
 * <p>The first byte selects the message type: {@code MESSAGE} dispatches the remaining
 * bytes to {@code handleMessage}, {@code CLOSE} triggers {@code handleClose}. An unknown
 * type or a payload that cannot be read is reported through an exceptionally completed
 * future, consistent with how {@code handleMessage} reports its own failures.
 *
 * @param sender the endpoint that sent the payload (not used)
 * @param payload the raw payload: one type byte followed by the message body
 * @return a future for the encoded reply
 */
private CompletableFuture<byte[]> handle(Endpoint sender, byte[] payload) {
    try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) {
        byte type = input.readByte();
        switch (type) {
            case MESSAGE:
                return handleMessage(IOUtils.toByteArray(input));
            case CLOSE:
                return handleClose();
            default:
                return Tools.exceptionalFuture(new IllegalStateException("Invalid message type"));
        }
    } catch (IOException e) {
        // Covers an empty payload (EOF on the type byte) as well as read failures.
        return Tools.exceptionalFuture(e);
    }
}
/**
 * Decodes a request from the other side of the connection, dispatches it to the handler
 * registered for its class and encodes the handler's outcome as the reply.
 *
 * @param message the serialized request (without the leading type byte)
 * @return a future for the encoded reply; completed exceptionally when no handler is
 *         registered for the request class or when decoding/encoding fails
 */
@SuppressWarnings("unchecked")
private CompletableFuture<byte[]> handleMessage(byte[] message) {
    try {
        Object request = context.serializer().readObject(new ByteArrayInputStream(message));
        Class<?> requestType = request.getClass();
        InternalHandler handler = handlers.get(requestType);
        if (handler == null) {
            log.warn("No handler registered on connection {}-{} for type {}",
                    partitionId, connectionId, requestType);
            return Tools.exceptionalFuture(new IllegalStateException(
                    "No handler registered for " + requestType));
        }
        return handler.handle(request).handle(this::encodeResponse);
    } catch (Exception e) {
        return Tools.exceptionalFuture(e);
    }
}

/**
 * Encodes a handler outcome as a status byte ({@code FAILURE} or {@code SUCCESS})
 * followed by the serialized error or result.
 *
 * @param result the handler result; used when {@code error} is null
 * @param error the handler failure, or null on success
 * @return the encoded reply
 * @throws IllegalArgumentException if the encoded reply exceeds {@code MAX_MESSAGE_SIZE}
 */
private byte[] encodeResponse(Object result, Throwable error) {
    boolean failed = error != null;
    try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
        out.write(failed ? FAILURE : SUCCESS);
        Object body = failed ? error : result;
        context.serializer().writeObject(body, out);
        byte[] encoded = out.toByteArray();
        if (encoded.length > MAX_MESSAGE_SIZE) {
            throw new IllegalArgumentException("response exceeds maximum message size " + MAX_MESSAGE_SIZE);
        }
        return encoded;
    } catch (IOException e) {
        Throwables.propagate(e);
        return null;
    }
}
/**
 * Handles a close request from the other side of the connection.
 *
 * <p>The local cleanup runs on this connection's context executor; afterwards the
 * returned future is completed with a single {@code SUCCESS} byte as acknowledgement.
 *
 * @return a future for the one-byte acknowledgement
 */
private CompletableFuture<byte[]> handleClose() {
    CompletableFuture<byte[]> acknowledgement = new CompletableFuture<>();
    context.executor().execute(() -> {
        close(null);
        acknowledgement.complete(new byte[] {SUCCESS});
    });
    return acknowledgement;
}
/**
 * Registers a handler that consumes messages of the given type without producing a
 * response. The consumer is adapted to a function that returns a null future.
 *
 * @param type the message class to handle
 * @param handler the consumer invoked for each message of that class
 * @param <T> the message type
 * @param <U> the (unused) response type
 * @return this connection
 */
@Override
public <T, U> Connection handler(Class<T> type, Consumer<T> handler) {
    Function<T, CompletableFuture<U>> adapter = request -> {
        handler.accept(request);
        return null;
    };
    return handler(type, adapter);
}
/**
 * Registers a request/response handler for messages of the given type, replacing any
 * handler previously registered for that class. The handler is bound to the calling
 * thread's {@link ThreadContext}, so this must be called from a thread that has one.
 *
 * @param type the message class to handle
 * @param handler the function invoked for each message of that class
 * @param <T> the message type
 * @param <U> the response type
 * @return this connection
 */
@Override
public <T, U> Connection handler(Class<T> type, Function<T, CompletableFuture<U>> handler) {
    // Parameterized logging already skips formatting when TRACE is disabled.
    log.trace("Registered handler on connection {}-{}: {}", partitionId, connectionId, type);
    InternalHandler registration = new InternalHandler(handler, ThreadContext.currentContextOrThrow());
    handlers.put(type, registration);
    return this;
}
/**
 * Adds a listener that is notified with the error when the connection is cleaned up
 * because of a failure.
 *
 * @param consumer the listener callback
 * @return the listener registration
 */
@Override
public Listener<Throwable> onException(Consumer<Throwable> consumer) {
    Listener<Throwable> registration = exceptionListeners.add(consumer);
    return registration;
}
/**
 * Adds a listener that is notified with this connection whenever it is cleaned up.
 *
 * @param consumer the listener callback
 * @return the listener registration
 */
@Override
public Listener<Connection> onClose(Consumer<Connection> consumer) {
    Listener<Connection> registration = closeListeners.add(consumer);
    return registration;
}
/**
 * Closes the connection by sending a {@code CLOSE} request to the remote subject and
 * cleaning up locally once the request completes, whether it succeeded or failed.
 *
 * <p>The returned future completes normally when the peer acknowledges with a
 * {@code SUCCESS} byte. It completes exceptionally with the messaging error (wrapped in a
 * {@link TransportException} when the root cause is
 * {@link MessagingException.NoRemoteHandler}), or with a {@link TransportException} when
 * the acknowledgement is missing, empty or not {@code SUCCESS}. Must be called from a
 * thread that has a current {@link ThreadContext}.
 *
 * @return a future completed once the close has been acknowledged or has failed
 */
@Override
public CompletableFuture<Void> close() {
    log.debug("Closing connection {}-{}", partitionId, connectionId);
    byte[] request = ByteBuffer.allocate(1).put(CLOSE).array();
    ThreadContext context = ThreadContext.currentContextOrThrow();
    CompletableFuture<Void> future = new CompletableFuture<>();
    messagingService.sendAndReceive(endpoint, remoteSubject, request, context.executor())
            .whenComplete((payload, error) -> {
                close(error);
                if (error != null) {
                    Throwable wrappedError = error;
                    if (Throwables.getRootCause(error) instanceof MessagingException.NoRemoteHandler) {
                        wrappedError = new TransportException(error);
                    }
                    future.completeExceptionally(wrappedError);
                } else if (payload != null && payload.length > 0 && payload[0] == SUCCESS) {
                    future.complete(null);
                } else {
                    // A null or empty acknowledgement used to throw inside this callback,
                    // leaving the future incomplete; treat it as a failed close instead.
                    future.completeExceptionally(new TransportException("Failed to close connection"));
                }
            });
    return future;
}
/**
 * Cleans up the connection: unregisters the handler for the local subject from the
 * messaging service, then notifies the exception listeners (only when an error is given)
 * and finally the close listeners.
 *
 * @param error the failure that caused the close, or null for a regular close
 */
private void close(Throwable error) {
    log.debug("Connection {}-{} closed", partitionId, connectionId);
    messagingService.unregisterHandler(localSubject);
    final boolean closedByFailure = error != null;
    if (closedByFailure) {
        exceptionListeners.accept(error);
    }
    closeListeners.accept(this);
}
/**
 * Connection mode used to indicate whether this side of the connection is
 * a client or server. Each mode carries the format of the subject it listens on
 * and the format of the subject it sends to; the two modes mirror each other.
 */
enum Mode {
    /**
     * Represents the client side of a bi-directional connection.
     */
    CLIENT("onos-copycat-%s-%d-client", "onos-copycat-%s-%d-server"),

    /**
     * Represents the server side of a bi-directional connection.
     */
    SERVER("onos-copycat-%s-%d-server", "onos-copycat-%s-%d-client");

    private final String localSubjectFormat;
    private final String remoteSubjectFormat;

    Mode(String localSubjectFormat, String remoteSubjectFormat) {
        this.localSubjectFormat = localSubjectFormat;
        this.remoteSubjectFormat = remoteSubjectFormat;
    }

    /**
     * Returns the local messaging service subject for the connection in this mode.
     * The subject is derived from the partition ID and the connection ID.
     *
     * @param partitionId the partition ID to which the connection belongs.
     * @param connectionId the connection ID.
     * @return the local subject for the connection.
     */
    String getLocalSubject(PartitionId partitionId, long connectionId) {
        return String.format(localSubjectFormat, partitionId, connectionId);
    }

    /**
     * Returns the remote messaging service subject for the connection in this mode.
     * The subject is derived from the partition ID and the connection ID.
     *
     * @param partitionId the partition ID to which the connection belongs.
     * @param connectionId the connection ID.
     * @return the remote subject for the connection.
     */
    String getRemoteSubject(PartitionId partitionId, long connectionId) {
        return String.format(remoteSubjectFormat, partitionId, connectionId);
    }
}
/**
 * Internal container for a handler/context pair. The handler is always invoked on the
 * executor of the thread context it was registered from.
 */
private static class InternalHandler {
    private final Function handler;
    private final ThreadContext context;

    InternalHandler(Function handler, ThreadContext context) {
        this.handler = handler;
        this.context = context;
    }

    /**
     * Invokes the handler with the given message on the handler's context executor.
     *
     * <p>The returned future mirrors the future produced by the handler. If the handler
     * returns null (as the adapter for consumer-style handlers does) the returned future
     * is left incomplete. If the handler throws, the returned future is completed
     * exceptionally so the request does not hang, and the exception is rethrown so the
     * executor still observes it as before.
     *
     * @param message the deserialized request
     * @return a future for the handler's response
     */
    @SuppressWarnings("unchecked")
    CompletableFuture<Object> handle(Object message) {
        CompletableFuture<Object> future = new CompletableFuture<>();
        context.executor().execute(() -> {
            CompletableFuture<Object> responseFuture;
            try {
                responseFuture = (CompletableFuture<Object>) handler.apply(message);
            } catch (RuntimeException | Error e) {
                future.completeExceptionally(e);
                throw e;
            }
            if (responseFuture != null) {
                responseFuture.whenComplete((response, error) -> {
                    if (error != null) {
                        future.completeExceptionally(error);
                    } else {
                        future.complete(response);
                    }
                });
            }
        });
        return future;
    }
}
}