| /* |
| * Copyright 2017-present Open Networking Laboratory |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.onosproject.store.primitives.impl; |
| |
import com.google.common.collect.Lists;
import io.atomix.catalyst.concurrent.SingleThreadContext;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Client;
import io.atomix.catalyst.transport.Server;
import io.atomix.catalyst.transport.Transport;
import io.atomix.copycat.protocol.ConnectRequest;
import io.atomix.copycat.protocol.ConnectResponse;
import io.atomix.copycat.protocol.PublishRequest;
import io.atomix.copycat.protocol.PublishResponse;
import io.atomix.copycat.protocol.Response;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.onlab.packet.IpAddress;
import org.onlab.util.Tools;
import org.onosproject.cluster.PartitionId;
import org.onosproject.store.cluster.messaging.Endpoint;
import org.onosproject.store.cluster.messaging.MessagingService;

import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.onlab.junit.TestTools.findAvailablePort;
| |
| /** |
| * Copycat transport test. |
| */ |
| public class CopycatTransportTest { |
| |
| private static final String IP_STRING = "127.0.0.1"; |
| |
| private Endpoint endpoint1 = new Endpoint(IpAddress.valueOf(IP_STRING), 5001); |
| private Endpoint endpoint2 = new Endpoint(IpAddress.valueOf(IP_STRING), 5002); |
| |
| private MessagingService service1; |
| private MessagingService service2; |
| |
| private Transport clientTransport; |
| private ThreadContext clientContext; |
| |
| private Transport serverTransport; |
| private ThreadContext serverContext; |
| |
| @Before |
| public void setUp() throws Exception { |
| Map<Endpoint, TestMessagingService> services = new ConcurrentHashMap<>(); |
| |
| endpoint1 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5001)); |
| service1 = new TestMessagingService(endpoint1, services); |
| clientTransport = new CopycatTransport(PartitionId.from(1), service1); |
| clientContext = new SingleThreadContext("client-test-%d", CatalystSerializers.getSerializer()); |
| |
| endpoint2 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5003)); |
| service2 = new TestMessagingService(endpoint2, services); |
| serverTransport = new CopycatTransport(PartitionId.from(1), service2); |
| serverContext = new SingleThreadContext("server-test-%d", CatalystSerializers.getSerializer()); |
| } |
| |
| @After |
| public void tearDown() throws Exception { |
| if (clientContext != null) { |
| clientContext.close(); |
| } |
| if (serverContext != null) { |
| serverContext.close(); |
| } |
| } |
| |
| /** |
| * Tests sending a message from the client side of a Copycat connection to the server side. |
| */ |
| @Test |
| public void testCopycatClientConnectionSend() throws Exception { |
| Client client = clientTransport.client(); |
| Server server = serverTransport.server(); |
| |
| CountDownLatch latch = new CountDownLatch(4); |
| CountDownLatch listenLatch = new CountDownLatch(1); |
| CountDownLatch handlerLatch = new CountDownLatch(1); |
| serverContext.executor().execute(() -> { |
| server.listen(new Address(IP_STRING, endpoint2.port()), connection -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| connection.handler(ConnectRequest.class, request -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| return CompletableFuture.completedFuture(ConnectResponse.builder() |
| .withStatus(Response.Status.OK) |
| .withLeader(new Address(IP_STRING, endpoint2.port())) |
| .withMembers(Lists.newArrayList(new Address(IP_STRING, endpoint2.port()))) |
| .build()); |
| }); |
| handlerLatch.countDown(); |
| }).thenRun(listenLatch::countDown); |
| }); |
| |
| listenLatch.await(5, TimeUnit.SECONDS); |
| |
| clientContext.executor().execute(() -> { |
| client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| try { |
| handlerLatch.await(5, TimeUnit.SECONDS); |
| } catch (InterruptedException e) { |
| fail(); |
| } |
| connection.<ConnectRequest, ConnectResponse>send(ConnectRequest.builder() |
| .withClientId(UUID.randomUUID().toString()) |
| .build()) |
| .thenAccept(response -> { |
| clientContext.checkThread(); |
| assertNotNull(response); |
| assertEquals(Response.Status.OK, response.status()); |
| latch.countDown(); |
| }); |
| }); |
| }); |
| |
| latch.await(5, TimeUnit.SECONDS); |
| assertEquals(0, latch.getCount()); |
| } |
| |
| /** |
| * Tests sending a message from the server side of a Copycat connection to the client side. |
| */ |
| @Test |
| public void testCopycatServerConnectionSend() throws Exception { |
| Client client = clientTransport.client(); |
| Server server = serverTransport.server(); |
| |
| CountDownLatch latch = new CountDownLatch(4); |
| CountDownLatch listenLatch = new CountDownLatch(1); |
| serverContext.executor().execute(() -> { |
| server.listen(new Address(IP_STRING, endpoint2.port()), connection -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| serverContext.schedule(Duration.ofMillis(100), () -> { |
| connection.<PublishRequest, PublishResponse>send(PublishRequest.builder() |
| .withSession(1) |
| .withEventIndex(3) |
| .withPreviousIndex(2) |
| .build()) |
| .thenAccept(response -> { |
| serverContext.checkThread(); |
| assertEquals(Response.Status.OK, response.status()); |
| assertEquals(1, response.index()); |
| latch.countDown(); |
| }); |
| }); |
| }).thenRun(listenLatch::countDown); |
| }); |
| |
| listenLatch.await(5, TimeUnit.SECONDS); |
| |
| clientContext.executor().execute(() -> { |
| client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| connection.handler(PublishRequest.class, request -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| assertEquals(1, request.session()); |
| assertEquals(3, request.eventIndex()); |
| assertEquals(2, request.previousIndex()); |
| return CompletableFuture.completedFuture(PublishResponse.builder() |
| .withStatus(Response.Status.OK) |
| .withIndex(1) |
| .build()); |
| }); |
| }); |
| }); |
| |
| latch.await(5, TimeUnit.SECONDS); |
| assertEquals(0, latch.getCount()); |
| } |
| |
| /** |
| * Tests closing the server side of a Copycat connection. |
| */ |
| @Test |
| public void testCopycatClientConnectionClose() throws Exception { |
| Client client = clientTransport.client(); |
| Server server = serverTransport.server(); |
| |
| CountDownLatch latch = new CountDownLatch(5); |
| CountDownLatch listenLatch = new CountDownLatch(1); |
| serverContext.executor().execute(() -> { |
| server.listen(new Address(IP_STRING, endpoint2.port()), connection -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| connection.closeListener(c -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| }); |
| }).thenRun(listenLatch::countDown); |
| }); |
| |
| listenLatch.await(5, TimeUnit.SECONDS); |
| |
| clientContext.executor().execute(() -> { |
| client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| connection.closeListener(c -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| }); |
| clientContext.schedule(Duration.ofMillis(100), () -> { |
| connection.close().whenComplete((result, error) -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| }); |
| }); |
| }); |
| }); |
| |
| latch.await(5, TimeUnit.SECONDS); |
| assertEquals(0, latch.getCount()); |
| } |
| |
| /** |
| * Tests closing the server side of a Copycat connection. |
| */ |
| @Test |
| public void testCopycatServerConnectionClose() throws Exception { |
| Client client = clientTransport.client(); |
| Server server = serverTransport.server(); |
| |
| CountDownLatch latch = new CountDownLatch(5); |
| CountDownLatch listenLatch = new CountDownLatch(1); |
| serverContext.executor().execute(() -> { |
| server.listen(new Address(IP_STRING, endpoint2.port()), connection -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| connection.closeListener(c -> { |
| latch.countDown(); |
| }); |
| serverContext.schedule(Duration.ofMillis(100), () -> { |
| connection.close().whenComplete((result, error) -> { |
| serverContext.checkThread(); |
| latch.countDown(); |
| }); |
| }); |
| }).thenRun(listenLatch::countDown); |
| }); |
| |
| listenLatch.await(5, TimeUnit.SECONDS); |
| |
| clientContext.executor().execute(() -> { |
| client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> { |
| clientContext.checkThread(); |
| latch.countDown(); |
| connection.closeListener(c -> { |
| latch.countDown(); |
| }); |
| }); |
| }); |
| |
| latch.await(5, TimeUnit.SECONDS); |
| assertEquals(0, latch.getCount()); |
| } |
| |
| /** |
| * Custom implementation of {@code MessagingService} used for testing. Really, this should |
| * be mocked but suffices for now. |
| */ |
| public static final class TestMessagingService implements MessagingService { |
| private final Endpoint endpoint; |
| private final Map<Endpoint, TestMessagingService> services; |
| private final Map<String, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>>> handlers = |
| new ConcurrentHashMap<>(); |
| |
| TestMessagingService(Endpoint endpoint, Map<Endpoint, TestMessagingService> services) { |
| this.endpoint = endpoint; |
| this.services = services; |
| services.put(endpoint, this); |
| } |
| |
| private CompletableFuture<byte[]> handle(Endpoint ep, String type, byte[] message, Executor executor) { |
| BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler = handlers.get(type); |
| if (handler == null) { |
| return Tools.exceptionalFuture(new IllegalStateException()); |
| } |
| return handler.apply(ep, message).thenApplyAsync(r -> r, executor); |
| } |
| |
| @Override |
| public CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload) { |
| // Unused for testing |
| return null; |
| } |
| |
| @Override |
| public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) { |
| // Unused for testing |
| return null; |
| } |
| |
| @Override |
| public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) { |
| TestMessagingService service = services.get(ep); |
| if (service == null) { |
| return Tools.exceptionalFuture(new IllegalStateException()); |
| } |
| return service.handle(endpoint, type, payload, executor); |
| } |
| |
| @Override |
| public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) { |
| // Unused for testing |
| } |
| |
| @Override |
| public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) { |
| // Unused for testing |
| } |
| |
| @Override |
| public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) { |
| handlers.put(type, handler); |
| } |
| |
| @Override |
| public void unregisterHandler(String type) { |
| handlers.remove(type); |
| } |
| } |
| |
| } |