| /* |
| * Copyright 2016-present Open Networking Laboratory |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.onosproject.store.primitives.impl; |
| |
| import java.io.ByteArrayInputStream; |
| import java.io.ByteArrayOutputStream; |
| import java.io.DataInputStream; |
| import java.io.DataOutputStream; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.net.SocketException; |
| import java.nio.ByteBuffer; |
| import java.util.Map; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.function.Consumer; |
| import java.util.function.Function; |
| |
| import com.google.common.base.Throwables; |
| import io.atomix.catalyst.concurrent.Listener; |
| import io.atomix.catalyst.concurrent.Listeners; |
| import io.atomix.catalyst.concurrent.ThreadContext; |
| import io.atomix.catalyst.serializer.SerializationException; |
| import io.atomix.catalyst.transport.Connection; |
| import io.atomix.catalyst.transport.TransportException; |
| import io.atomix.catalyst.util.reference.ReferenceCounted; |
| import org.apache.commons.io.IOUtils; |
| import org.onlab.util.Tools; |
| import org.onosproject.cluster.PartitionId; |
| import org.onosproject.store.cluster.messaging.Endpoint; |
| import org.onosproject.store.cluster.messaging.MessagingException; |
| import org.onosproject.store.cluster.messaging.MessagingService; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import static com.google.common.base.Preconditions.checkNotNull; |
| import static org.onosproject.store.primitives.impl.CopycatTransport.CLOSE; |
| import static org.onosproject.store.primitives.impl.CopycatTransport.FAILURE; |
| import static org.onosproject.store.primitives.impl.CopycatTransport.MESSAGE; |
| import static org.onosproject.store.primitives.impl.CopycatTransport.SUCCESS; |
| |
| /** |
| * Base Copycat Transport connection. |
| */ |
| public class CopycatTransportConnection implements Connection { |
| private final Logger log = LoggerFactory.getLogger(getClass()); |
| private final long connectionId; |
| private final String localSubject; |
| private final String remoteSubject; |
| private final PartitionId partitionId; |
| private final Endpoint endpoint; |
| private final MessagingService messagingService; |
| private final ThreadContext context; |
| private final Map<Class, InternalHandler> handlers = new ConcurrentHashMap<>(); |
| private final Listeners<Throwable> exceptionListeners = new Listeners<>(); |
| private final Listeners<Connection> closeListeners = new Listeners<>(); |
| |
| CopycatTransportConnection( |
| long connectionId, |
| Mode mode, |
| PartitionId partitionId, |
| Endpoint endpoint, |
| MessagingService messagingService, |
| ThreadContext context) { |
| this.connectionId = connectionId; |
| this.partitionId = checkNotNull(partitionId, "partitionId cannot be null"); |
| this.localSubject = mode.getLocalSubject(partitionId, connectionId); |
| this.remoteSubject = mode.getRemoteSubject(partitionId, connectionId); |
| this.endpoint = checkNotNull(endpoint, "endpoint cannot be null"); |
| this.messagingService = checkNotNull(messagingService, "messagingService cannot be null"); |
| this.context = checkNotNull(context, "context cannot be null"); |
| messagingService.registerHandler(localSubject, this::handle); |
| } |
| |
| @Override |
| public CompletableFuture<Void> send(Object message) { |
| ThreadContext context = ThreadContext.currentContextOrThrow(); |
| CompletableFuture<Void> future = new CompletableFuture<>(); |
| try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { |
| DataOutputStream dos = new DataOutputStream(baos); |
| dos.writeByte(MESSAGE); |
| context.serializer().writeObject(message, baos); |
| if (message instanceof ReferenceCounted) { |
| ((ReferenceCounted<?>) message).release(); |
| } |
| |
| messagingService.sendAsync(endpoint, remoteSubject, baos.toByteArray()) |
| .whenComplete((r, e) -> { |
| if (e != null) { |
| context.executor().execute(() -> future.completeExceptionally(e)); |
| } else { |
| context.executor().execute(() -> future.complete(null)); |
| } |
| }); |
| } catch (SerializationException | IOException e) { |
| future.completeExceptionally(e); |
| } |
| return future; |
| } |
| |
| @Override |
| public <T, U> CompletableFuture<U> sendAndReceive(T message) { |
| ThreadContext context = ThreadContext.currentContextOrThrow(); |
| CompletableFuture<U> future = new CompletableFuture<>(); |
| try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { |
| DataOutputStream dos = new DataOutputStream(baos); |
| dos.writeByte(MESSAGE); |
| context.serializer().writeObject(message, baos); |
| if (message instanceof ReferenceCounted) { |
| ((ReferenceCounted<?>) message).release(); |
| } |
| messagingService.sendAndReceive(endpoint, |
| remoteSubject, |
| baos.toByteArray(), |
| context.executor()) |
| .whenComplete((response, error) -> handleResponse(response, error, future)); |
| } catch (SerializationException | IOException e) { |
| future.completeExceptionally(e); |
| } |
| return future; |
| } |
| |
| /** |
| * Handles a response received from the other side of the connection. |
| */ |
| private <T> void handleResponse( |
| byte[] response, |
| Throwable error, |
| CompletableFuture<T> future) { |
| if (error != null) { |
| Throwable rootCause = Throwables.getRootCause(error); |
| if (rootCause instanceof MessagingException || rootCause instanceof SocketException) { |
| future.completeExceptionally(new TransportException(error)); |
| if (rootCause instanceof MessagingException.NoRemoteHandler) { |
| close(rootCause); |
| } |
| } else { |
| future.completeExceptionally(error); |
| } |
| return; |
| } |
| |
| checkNotNull(response); |
| InputStream input = new ByteArrayInputStream(response); |
| try { |
| byte status = (byte) input.read(); |
| if (status == FAILURE) { |
| Throwable t = context.serializer().readObject(input); |
| future.completeExceptionally(t); |
| } else { |
| try { |
| future.complete(context.serializer().readObject(input)); |
| } catch (SerializationException e) { |
| future.completeExceptionally(e); |
| } |
| } |
| } catch (IOException e) { |
| future.completeExceptionally(e); |
| } |
| } |
| |
| /** |
| * Handles a message sent to the connection. |
| */ |
| private CompletableFuture<byte[]> handle(Endpoint sender, byte[] payload) { |
| try (DataInputStream input = new DataInputStream(new ByteArrayInputStream(payload))) { |
| byte type = input.readByte(); |
| switch (type) { |
| case MESSAGE: |
| return handleMessage(IOUtils.toByteArray(input)); |
| case CLOSE: |
| return handleClose(); |
| default: |
| throw new IllegalStateException("Invalid message type"); |
| } |
| } catch (IOException e) { |
| Throwables.propagate(e); |
| return null; |
| } |
| } |
| |
| /** |
| * Handles a message from the other side of the connection. |
| */ |
| @SuppressWarnings("unchecked") |
| private CompletableFuture<byte[]> handleMessage(byte[] message) { |
| try { |
| Object request = context.serializer().readObject(new ByteArrayInputStream(message)); |
| InternalHandler handler = handlers.get(request.getClass()); |
| if (handler == null) { |
| log.warn("No handler registered on connection {}-{} for type {}", |
| partitionId, connectionId, request.getClass()); |
| return Tools.exceptionalFuture(new IllegalStateException( |
| "No handler registered for " + request.getClass())); |
| } |
| |
| return handler.handle(request).handle((result, error) -> { |
| try (ByteArrayOutputStream baos = new ByteArrayOutputStream()) { |
| baos.write(error != null ? FAILURE : SUCCESS); |
| context.serializer().writeObject(error != null ? error : result, baos); |
| return baos.toByteArray(); |
| } catch (IOException e) { |
| Throwables.propagate(e); |
| return null; |
| } |
| }); |
| } catch (Exception e) { |
| return Tools.exceptionalFuture(e); |
| } |
| } |
| |
| /** |
| * Handles a close request from the other side of the connection. |
| */ |
| private CompletableFuture<byte[]> handleClose() { |
| CompletableFuture<byte[]> future = new CompletableFuture<>(); |
| context.executor().execute(() -> { |
| close(null); |
| ByteBuffer responseBuffer = ByteBuffer.allocate(1); |
| responseBuffer.put(SUCCESS); |
| future.complete(responseBuffer.array()); |
| }); |
| return future; |
| } |
| |
| @Override |
| public <T, U> Connection handler(Class<T> type, Consumer<T> handler) { |
| return handler(type, r -> { |
| handler.accept(r); |
| return null; |
| }); |
| } |
| |
| @Override |
| public <T, U> Connection handler(Class<T> type, Function<T, CompletableFuture<U>> handler) { |
| if (log.isTraceEnabled()) { |
| log.trace("Registered handler on connection {}-{}: {}", partitionId, connectionId, type); |
| } |
| handlers.put(type, new InternalHandler(handler, ThreadContext.currentContextOrThrow())); |
| return this; |
| } |
| |
| @Override |
| public Listener<Throwable> onException(Consumer<Throwable> consumer) { |
| return exceptionListeners.add(consumer); |
| } |
| |
| @Override |
| public Listener<Connection> onClose(Consumer<Connection> consumer) { |
| return closeListeners.add(consumer); |
| } |
| |
| @Override |
| public CompletableFuture<Void> close() { |
| log.debug("Closing connection {}-{}", partitionId, connectionId); |
| |
| ByteBuffer requestBuffer = ByteBuffer.allocate(1); |
| requestBuffer.put(CLOSE); |
| |
| ThreadContext context = ThreadContext.currentContextOrThrow(); |
| CompletableFuture<Void> future = new CompletableFuture<>(); |
| messagingService.sendAndReceive(endpoint, remoteSubject, requestBuffer.array(), context.executor()) |
| .whenComplete((payload, error) -> { |
| close(error); |
| Throwable wrappedError = error; |
| if (error != null) { |
| Throwable rootCause = Throwables.getRootCause(error); |
| if (MessagingException.class.isAssignableFrom(rootCause.getClass())) { |
| wrappedError = new TransportException(error); |
| } |
| future.completeExceptionally(wrappedError); |
| } else { |
| ByteBuffer responseBuffer = ByteBuffer.wrap(payload); |
| if (responseBuffer.get() == SUCCESS) { |
| future.complete(null); |
| } else { |
| future.completeExceptionally(new TransportException("Failed to close connection")); |
| } |
| } |
| }); |
| return future; |
| } |
| |
| /** |
| * Cleans up the connection, unregistering handlers registered on the MessagingService. |
| */ |
| private void close(Throwable error) { |
| log.debug("Connection {}-{} closed", partitionId, connectionId); |
| messagingService.unregisterHandler(localSubject); |
| if (error != null) { |
| exceptionListeners.accept(error); |
| } |
| closeListeners.accept(this); |
| } |
| |
| /** |
| * Connection mode used to indicate whether this side of the connection is |
| * a client or server. |
| */ |
| enum Mode { |
| |
| /** |
| * Represents the client side of a bi-directional connection. |
| */ |
| CLIENT { |
| @Override |
| String getLocalSubject(PartitionId partitionId, long connectionId) { |
| return String.format("onos-copycat-%s-%d-client", partitionId, connectionId); |
| } |
| |
| @Override |
| String getRemoteSubject(PartitionId partitionId, long connectionId) { |
| return String.format("onos-copycat-%s-%d-server", partitionId, connectionId); |
| } |
| }, |
| |
| /** |
| * Represents the server side of a bi-directional connection. |
| */ |
| SERVER { |
| @Override |
| String getLocalSubject(PartitionId partitionId, long connectionId) { |
| return String.format("onos-copycat-%s-%d-server", partitionId, connectionId); |
| } |
| |
| @Override |
| String getRemoteSubject(PartitionId partitionId, long connectionId) { |
| return String.format("onos-copycat-%s-%d-client", partitionId, connectionId); |
| } |
| }; |
| |
| /** |
| * Returns the local messaging service subject for the connection in this mode. |
| * Subjects generated by the connection mode are guaranteed to be globally unique. |
| * |
| * @param partitionId the partition ID to which the connection belongs. |
| * @param connectionId the connection ID. |
| * @return the globally unique local subject for the connection. |
| */ |
| abstract String getLocalSubject(PartitionId partitionId, long connectionId); |
| |
| /** |
| * Returns the remote messaging service subject for the connection in this mode. |
| * Subjects generated by the connection mode are guaranteed to be globally unique. |
| * |
| * @param partitionId the partition ID to which the connection belongs. |
| * @param connectionId the connection ID. |
| * @return the globally unique remote subject for the connection. |
| */ |
| abstract String getRemoteSubject(PartitionId partitionId, long connectionId); |
| } |
| |
| /** |
| * Internal container for a handler/context pair. |
| */ |
| private static class InternalHandler { |
| private final Function handler; |
| private final ThreadContext context; |
| |
| InternalHandler(Function handler, ThreadContext context) { |
| this.handler = handler; |
| this.context = context; |
| } |
| |
| @SuppressWarnings("unchecked") |
| CompletableFuture<Object> handle(Object message) { |
| CompletableFuture<Object> future = new CompletableFuture<>(); |
| context.executor().execute(() -> { |
| CompletableFuture<Object> responseFuture = (CompletableFuture<Object>) handler.apply(message); |
| if (responseFuture != null) { |
| responseFuture.whenComplete((r, e) -> { |
| if (e != null) { |
| future.completeExceptionally((Throwable) e); |
| } else { |
| future.complete(r); |
| } |
| }); |
| } |
| }); |
| return future; |
| } |
| } |
| } |