blob: de76d0dd6ac32c03b4f7cdaf2a57ba737830f387 [file] [log] [blame]
/*
* Copyright 2017-present Open Networking Laboratory
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.onosproject.store.primitives.impl;
import java.time.Duration;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.BiConsumer;
import java.util.function.BiFunction;
import com.google.common.collect.Lists;
import io.atomix.catalyst.concurrent.SingleThreadContext;
import io.atomix.catalyst.concurrent.ThreadContext;
import io.atomix.catalyst.transport.Address;
import io.atomix.catalyst.transport.Client;
import io.atomix.catalyst.transport.Server;
import io.atomix.catalyst.transport.Transport;
import io.atomix.copycat.protocol.ConnectRequest;
import io.atomix.copycat.protocol.ConnectResponse;
import io.atomix.copycat.protocol.Response;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.onlab.packet.IpAddress;
import org.onlab.util.Tools;
import org.onosproject.cluster.PartitionId;
import org.onosproject.store.cluster.messaging.Endpoint;
import org.onosproject.store.cluster.messaging.MessagingService;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;
import static org.onlab.junit.TestTools.findAvailablePort;
/**
 * Copycat transport test.
 * <p>
 * Each test wires a client-side and a server-side {@link CopycatTransport} to a pair of
 * {@link TestMessagingService} instances that deliver messages to each other through a shared
 * in-memory map. The tests then verify that connect, request/response and close callbacks fire,
 * and that they run on the expected Catalyst {@link ThreadContext}.
 */
public class CopycatTransportTest {
    private static final String IP_STRING = "127.0.0.1";

    /** Upper bound, in seconds, on every blocking wait performed by these tests. */
    private static final long TIMEOUT_SECONDS = 5;

    // Assigned in setUp() with freshly discovered free ports.
    private Endpoint endpoint1;
    private Endpoint endpoint2;
    private MessagingService service1;
    private MessagingService service2;
    private Transport clientTransport;
    private ThreadContext clientContext;
    private Transport serverTransport;
    private ThreadContext serverContext;

    @Before
    public void setUp() throws Exception {
        // Shared registry through which the two test messaging services locate each other.
        Map<Endpoint, TestMessagingService> services = new ConcurrentHashMap<>();

        endpoint1 = new Endpoint(IpAddress.valueOf(IP_STRING), findAvailablePort(5001));
        service1 = new TestMessagingService(endpoint1, services);
        clientTransport = new CopycatTransport(PartitionId.from(1), service1);
        clientContext = new SingleThreadContext("client-test-%d", CatalystSerializers.getSerializer());

        endpoint2 = new Endpoint(IpAddress.valueOf(IP_STRING), findAvailablePort(5003));
        service2 = new TestMessagingService(endpoint2, services);
        serverTransport = new CopycatTransport(PartitionId.from(1), service2);
        serverContext = new SingleThreadContext("server-test-%d", CatalystSerializers.getSerializer());
    }

    @After
    public void tearDown() throws Exception {
        if (clientContext != null) {
            clientContext.close();
        }
        if (serverContext != null) {
            serverContext.close();
        }
    }

    /**
     * Waits up to {@link #TIMEOUT_SECONDS} for {@code latch} to reach zero and fails the test,
     * reporting how many events are still outstanding, if it does not.
     *
     * @param latch       latch to wait on
     * @param description human-readable name of what is being waited for
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private static void awaitLatch(CountDownLatch latch, String description) throws InterruptedException {
        if (!latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
            fail("Timed out waiting for " + description + "; "
                    + latch.getCount() + " event(s) still outstanding");
        }
    }

    /**
     * Tests sending a message from the client side of a Copycat connection to the server side.
     */
    @Test
    public void testCopycatClientConnectionSend() throws Exception {
        Client client = clientTransport.client();
        Server server = serverTransport.server();

        // Expected events: server accepts the connection, server handles the request,
        // client connects, client receives an OK response.
        CountDownLatch latch = new CountDownLatch(4);
        CountDownLatch listenLatch = new CountDownLatch(1);
        CountDownLatch handlerLatch = new CountDownLatch(1);

        serverContext.executor().execute(() -> {
            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
                serverContext.checkThread();
                latch.countDown();
                connection.handler(ConnectRequest.class, request -> {
                    serverContext.checkThread();
                    latch.countDown();
                    return CompletableFuture.completedFuture(ConnectResponse.builder()
                            .withStatus(Response.Status.OK)
                            .withLeader(new Address(IP_STRING, endpoint2.port()))
                            .withMembers(Lists.newArrayList(new Address(IP_STRING, endpoint2.port())))
                            .build());
                });
                // Signals the client that the ConnectRequest handler is now registered.
                handlerLatch.countDown();
            }).thenRun(listenLatch::countDown);
        });
        awaitLatch(listenLatch, "server to start listening");

        clientContext.executor().execute(() -> {
            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
                clientContext.checkThread();
                latch.countDown();
                try {
                    // Ensure the server-side handler exists before the request is sent.
                    handlerLatch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS);
                } catch (InterruptedException e) {
                    // Preserve the interrupt; the main latch will not reach zero and the test fails.
                    Thread.currentThread().interrupt();
                    return;
                }
                connection.<ConnectRequest, ConnectResponse>sendAndReceive(ConnectRequest.builder()
                        .withClientId(UUID.randomUUID().toString())
                        .build())
                        .thenAccept(response -> {
                            clientContext.checkThread();
                            assertNotNull(response);
                            assertEquals(Response.Status.OK, response.status());
                            latch.countDown();
                        });
            });
        });

        awaitLatch(latch, "client-to-server request/response");
    }

    /**
     * Tests sending a message from the server side of a Copycat connection to the client side.
     */
    @Test
    public void testCopycatServerConnectionSend() throws Exception {
        Client client = clientTransport.client();
        Server server = serverTransport.server();

        // Expected events: server accepts the connection, server receives an OK response,
        // client connects, client handles the request.
        CountDownLatch latch = new CountDownLatch(4);
        CountDownLatch listenLatch = new CountDownLatch(1);

        serverContext.executor().execute(() -> {
            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
                serverContext.checkThread();
                latch.countDown();
                // Delay the send to give the client time to register its handler.
                serverContext.schedule(Duration.ofMillis(100), () -> {
                    connection.<ConnectRequest, ConnectResponse>sendAndReceive(ConnectRequest.builder()
                            .withClientId("foo")
                            .build())
                            .thenAccept(response -> {
                                serverContext.checkThread();
                                assertEquals(Response.Status.OK, response.status());
                                latch.countDown();
                            });
                });
            }).thenRun(listenLatch::countDown);
        });
        awaitLatch(listenLatch, "server to start listening");

        clientContext.executor().execute(() -> {
            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
                clientContext.checkThread();
                latch.countDown();
                connection.handler(ConnectRequest.class, request -> {
                    clientContext.checkThread();
                    latch.countDown();
                    assertEquals("foo", request.client());
                    return CompletableFuture.completedFuture(ConnectResponse.builder()
                            .withStatus(Response.Status.OK)
                            .withLeader(new Address(IP_STRING, endpoint2.port()))
                            .withMembers(Lists.newArrayList(new Address(IP_STRING, endpoint2.port())))
                            .build());
                });
            });
        });

        awaitLatch(latch, "server-to-client request/response");
    }

    /**
     * Tests closing the client side of a Copycat connection.
     */
    @Test
    public void testCopycatClientConnectionClose() throws Exception {
        Client client = clientTransport.client();
        Server server = serverTransport.server();

        // Expected events: server accepts the connection, server sees the close,
        // client connects, client sees the close, client close() completes.
        CountDownLatch latch = new CountDownLatch(5);
        CountDownLatch listenLatch = new CountDownLatch(1);

        serverContext.executor().execute(() -> {
            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
                serverContext.checkThread();
                latch.countDown();
                connection.onClose(c -> {
                    serverContext.checkThread();
                    latch.countDown();
                });
            }).thenRun(listenLatch::countDown);
        });
        awaitLatch(listenLatch, "server to start listening");

        clientContext.executor().execute(() -> {
            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
                clientContext.checkThread();
                latch.countDown();
                connection.onClose(c -> {
                    clientContext.checkThread();
                    latch.countDown();
                });
                clientContext.schedule(Duration.ofMillis(100), () -> {
                    connection.close().whenComplete((result, error) -> {
                        clientContext.checkThread();
                        latch.countDown();
                    });
                });
            });
        });

        awaitLatch(latch, "client-initiated connection close");
    }

    /**
     * Tests closing the server side of a Copycat connection.
     */
    @Test
    public void testCopycatServerConnectionClose() throws Exception {
        Client client = clientTransport.client();
        Server server = serverTransport.server();

        // Expected events: server accepts the connection, server sees the close,
        // server close() completes, client connects, client sees the close.
        CountDownLatch latch = new CountDownLatch(5);
        CountDownLatch listenLatch = new CountDownLatch(1);

        serverContext.executor().execute(() -> {
            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
                serverContext.checkThread();
                latch.countDown();
                connection.onClose(c -> {
                    latch.countDown();
                });
                serverContext.schedule(Duration.ofMillis(100), () -> {
                    connection.close().whenComplete((result, error) -> {
                        serverContext.checkThread();
                        latch.countDown();
                    });
                });
            }).thenRun(listenLatch::countDown);
        });
        awaitLatch(listenLatch, "server to start listening");

        clientContext.executor().execute(() -> {
            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
                clientContext.checkThread();
                latch.countDown();
                connection.onClose(c -> {
                    latch.countDown();
                });
            });
        });

        awaitLatch(latch, "server-initiated connection close");
    }

    /**
     * Custom implementation of {@code MessagingService} used for testing. Really, this should
     * be mocked but suffices for now.
     * <p>
     * Instances register themselves in a shared endpoint map on construction.
     * {@link #sendAndReceive(Endpoint, String, byte[], Executor)} looks up the target service
     * in that map and invokes its handler directly, without any network I/O. Only that method
     * and the {@code CompletableFuture}-returning {@link #registerHandler(String, BiFunction)}
     * and {@link #unregisterHandler(String)} are implemented. The remaining operations do nothing,
     * and the non-void ones return {@code null}.
     */
    public static final class TestMessagingService implements MessagingService {
        private final Endpoint endpoint;
        private final Map<Endpoint, TestMessagingService> services;
        // Message type -> asynchronous handler registered via registerHandler(String, BiFunction).
        private final Map<String, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>>> handlers =
                new ConcurrentHashMap<>();

        /**
         * Creates a service bound to {@code endpoint} and registers it in {@code services}.
         *
         * @param endpoint local endpoint of this service
         * @param services shared registry used to route messages between services
         */
        TestMessagingService(Endpoint endpoint, Map<Endpoint, TestMessagingService> services) {
            this.endpoint = endpoint;
            this.services = services;
            services.put(endpoint, this);
        }

        /**
         * Dispatches {@code message} from {@code ep} to the handler registered for {@code type}.
         *
         * @return the handler's reply, completed on {@code executor}; or a future failed with
         *         {@link IllegalStateException} if no handler is registered for {@code type}
         */
        private CompletableFuture<byte[]> handle(Endpoint ep, String type, byte[] message, Executor executor) {
            BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler = handlers.get(type);
            if (handler == null) {
                return Tools.exceptionalFuture(new IllegalStateException());
            }
            // Hop onto the caller's executor so completion callbacks run in the caller's context.
            return handler.apply(ep, message).thenApplyAsync(r -> r, executor);
        }

        @Override
        public CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload) {
            // Unused for testing
            return null;
        }

        @Override
        public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
            // Unused for testing
            return null;
        }

        /**
         * Delivers {@code payload} to the service registered for {@code ep}.
         *
         * @return the remote handler's reply, completed on {@code executor}; or a future failed
         *         with {@link IllegalStateException} if no service is registered for {@code ep}
         */
        @Override
        public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
            TestMessagingService service = services.get(ep);
            if (service == null) {
                return Tools.exceptionalFuture(new IllegalStateException());
            }
            return service.handle(endpoint, type, payload, executor);
        }

        @Override
        public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
            // Unused for testing
        }

        @Override
        public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
            // Unused for testing
        }

        @Override
        public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
            handlers.put(type, handler);
        }

        @Override
        public void unregisterHandler(String type) {
            handlers.remove(type);
        }
    }
}