[ONOS-5992] Ensure Copycat connections are closed when no remote handler is registered

Change-Id: Iec17fd09f0d715dbbe08c604057aeb00d677b939
diff --git a/core/store/dist/src/test/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManagerTest.java b/core/store/dist/src/test/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManagerTest.java
index 4122888..9c59d65 100644
--- a/core/store/dist/src/test/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManagerTest.java
+++ b/core/store/dist/src/test/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManagerTest.java
@@ -33,6 +33,7 @@
 import org.onosproject.net.provider.ProviderId;
 import org.onosproject.store.cluster.messaging.Endpoint;
 
+import java.net.ConnectException;
 import java.util.Arrays;
 import java.util.UUID;
 import java.util.concurrent.CompletableFuture;
@@ -127,6 +128,7 @@
         response = netty1.sendAsync(invalidEndPoint, subject, "hello world".getBytes());
         response.whenComplete((r, e) -> {
             assertNotNull(e);
+            assertTrue(e instanceof ConnectException);
             latch2.countDown();
         });
         Uninterruptibles.awaitUninterruptibly(latch2);
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
index 45cc529..fc94dd6 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransport.java
@@ -27,7 +27,6 @@
 import org.onosproject.store.cluster.messaging.MessagingService;
 
 import java.net.InetAddress;
-import java.net.InetSocketAddress;
 import java.net.UnknownHostException;
 import java.util.Map;
 
@@ -41,7 +40,6 @@
     private final PartitionId partitionId;
     private final MessagingService messagingService;
     private static final Map<Address, Endpoint> EP_LOOKUP_CACHE = Maps.newConcurrentMap();
-    private static final Map<Endpoint, Address> ADDRESS_LOOKUP_CACHE = Maps.newConcurrentMap();
 
     static final byte MESSAGE = 0x01;
     static final byte CONNECT = 0x02;
@@ -85,22 +83,4 @@
             }
         });
     }
-
-    /**
-     * Maps {@link Endpoint endpoint} to {@link Address address}.
-     * @param endpoint end point
-     * @return address
-     */
-    static Address toAddress(Endpoint endpoint) {
-        return ADDRESS_LOOKUP_CACHE.computeIfAbsent(endpoint, ep -> {
-            try {
-                InetAddress host = InetAddress.getByAddress(endpoint.host().toOctets());
-                int port = endpoint.port();
-                return new Address(new InetSocketAddress(host, port));
-            } catch (UnknownHostException e) {
-                Throwables.propagate(e);
-                return null;
-            }
-        });
-    }
 }
diff --git a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
index eebbf9c..b8596ae 100644
--- a/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
+++ b/core/store/primitives/src/main/java/org/onosproject/store/primitives/impl/CopycatTransportConnection.java
@@ -21,6 +21,7 @@
 import java.io.DataOutputStream;
 import java.io.IOException;
 import java.io.InputStream;
+import java.net.SocketException;
 import java.nio.ByteBuffer;
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
@@ -125,16 +126,7 @@
                                             remoteSubject,
                                             baos.toByteArray(),
                                             context.executor())
-                    .whenComplete((r, e) -> {
-                        Throwable wrappedError = e;
-                        if (e != null) {
-                            Throwable rootCause = Throwables.getRootCause(e);
-                            if (MessagingException.class.isAssignableFrom(rootCause.getClass())) {
-                                wrappedError = new TransportException(e);
-                            }
-                        }
-                        handleResponse(r, wrappedError, future);
-                    });
+                    .whenComplete((response, error) -> handleResponse(response, error, future));
         } catch (SerializationException | IOException e) {
             future.completeExceptionally(e);
         }
@@ -149,9 +141,18 @@
             Throwable error,
             CompletableFuture<T> future) {
         if (error != null) {
-            future.completeExceptionally(error);
+            Throwable rootCause = Throwables.getRootCause(error);
+            if (rootCause instanceof MessagingException || rootCause instanceof SocketException) {
+                future.completeExceptionally(new TransportException(error));
+                if (rootCause instanceof MessagingException.NoRemoteHandler) {
+                    close(rootCause);
+                }
+            } else {
+                future.completeExceptionally(error);
+            }
             return;
         }
+
         checkNotNull(response);
         InputStream input = new ByteArrayInputStream(response);
         try {
@@ -227,7 +228,7 @@
     private CompletableFuture<byte[]> handleClose() {
         CompletableFuture<byte[]> future = new CompletableFuture<>();
         context.executor().execute(() -> {
-            cleanup();
+            close(null);
             ByteBuffer responseBuffer = ByteBuffer.allocate(1);
             responseBuffer.put(SUCCESS);
             future.complete(responseBuffer.array());
@@ -273,7 +274,7 @@
         CompletableFuture<Void> future = new CompletableFuture<>();
         messagingService.sendAndReceive(endpoint, remoteSubject, requestBuffer.array(), context.executor())
                 .whenComplete((payload, error) -> {
-                    cleanup();
+                    close(error);
                     Throwable wrappedError = error;
                     if (error != null) {
                         Throwable rootCause = Throwables.getRootCause(error);
@@ -296,9 +297,12 @@
     /**
      * Cleans up the connection, unregistering handlers registered on the MessagingService.
      */
-    private void cleanup() {
+    private void close(Throwable error) {
         log.debug("Connection {}-{} closed", partitionId, connectionId);
         messagingService.unregisterHandler(localSubject);
+        if (error != null) {
+            exceptionListeners.accept(error);
+        }
         closeListeners.accept(this);
     }
 
diff --git a/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
index de76d0d..d62f4af 100644
--- a/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
+++ b/core/store/primitives/src/test/java/org/onosproject/store/primitives/impl/CopycatTransportTest.java
@@ -43,6 +43,7 @@
 import org.onlab.util.Tools;
 import org.onosproject.cluster.PartitionId;
 import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingException;
 import org.onosproject.store.cluster.messaging.MessagingService;
 
 import static org.junit.Assert.assertEquals;
@@ -60,8 +61,8 @@
     private Endpoint endpoint1 = new Endpoint(IpAddress.valueOf(IP_STRING), 5001);
     private Endpoint endpoint2 = new Endpoint(IpAddress.valueOf(IP_STRING), 5002);
 
-    private MessagingService service1;
-    private MessagingService service2;
+    private TestMessagingService clientService;
+    private TestMessagingService serverService;
 
     private Transport clientTransport;
     private ThreadContext clientContext;
@@ -74,13 +75,13 @@
         Map<Endpoint, TestMessagingService> services = new ConcurrentHashMap<>();
 
         endpoint1 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5001));
-        service1 = new TestMessagingService(endpoint1, services);
-        clientTransport = new CopycatTransport(PartitionId.from(1), service1);
+        clientService = new TestMessagingService(endpoint1, services);
+        clientTransport = new CopycatTransport(PartitionId.from(1), clientService);
         clientContext = new SingleThreadContext("client-test-%d", CatalystSerializers.getSerializer());
 
         endpoint2 = new Endpoint(IpAddress.valueOf("127.0.0.1"), findAvailablePort(5003));
-        service2 = new TestMessagingService(endpoint2, services);
-        serverTransport = new CopycatTransport(PartitionId.from(1), service2);
+        serverService = new TestMessagingService(endpoint2, services);
+        serverTransport = new CopycatTransport(PartitionId.from(1), serverService);
         serverContext = new SingleThreadContext("server-test-%d", CatalystSerializers.getSerializer());
     }
 
@@ -244,6 +245,41 @@
     }
 
     /**
+     * Tests that a client connection is closed on exception.
+     */
+    @Test
+    public void testCopycatClientConnectionCloseOnException() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        CountDownLatch closeLatch = new CountDownLatch(1);
+        CountDownLatch latch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                serverService.handlers.clear();
+                connection.onClose(c -> latch.countDown());
+                connection.<ConnectRequest, ConnectResponse>sendAndReceive(ConnectRequest.builder()
+                        .withClientId(UUID.randomUUID().toString())
+                        .build())
+                        .thenAccept(response -> fail());
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
      * Tests closing the server side of a Copycat connection.
      */
     @Test
@@ -286,6 +322,49 @@
     }
 
     /**
+     * Tests that a server connection is closed on exception.
+     */
+    @Test
+    public void testCopycatServerConnectionCloseOnException() throws Exception {
+        Client client = clientTransport.client();
+        Server server = serverTransport.server();
+
+        CountDownLatch latch = new CountDownLatch(1);
+        CountDownLatch listenLatch = new CountDownLatch(1);
+        CountDownLatch connectLatch = new CountDownLatch(1);
+        serverContext.executor().execute(() -> {
+            server.listen(new Address(IP_STRING, endpoint2.port()), connection -> {
+                serverContext.checkThread();
+                serverContext.executor().execute(() -> {
+                    try {
+                        connectLatch.await(5, TimeUnit.SECONDS);
+                    } catch (InterruptedException e) {
+                        fail();
+                    }
+                    clientService.handlers.clear();
+                    connection.onClose(c -> latch.countDown());
+                    connection.<ConnectRequest, ConnectResponse>sendAndReceive(ConnectRequest.builder()
+                            .withClientId("foo")
+                            .build())
+                            .thenAccept(response -> fail());
+                });
+            }).thenRun(listenLatch::countDown);
+        });
+
+        listenLatch.await(5, TimeUnit.SECONDS);
+
+        clientContext.executor().execute(() -> {
+            client.connect(new Address(IP_STRING, endpoint2.port())).thenAccept(connection -> {
+                clientContext.checkThread();
+                connectLatch.countDown();
+            });
+        });
+
+        latch.await(5, TimeUnit.SECONDS);
+        assertEquals(0, latch.getCount());
+    }
+
+    /**
      * Custom implementation of {@code MessagingService} used for testing. Really, this should
      * be mocked but suffices for now.
      */
@@ -304,7 +383,7 @@
         private CompletableFuture<byte[]> handle(Endpoint ep, String type, byte[] message, Executor executor) {
             BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler = handlers.get(type);
             if (handler == null) {
-                return Tools.exceptionalFuture(new IllegalStateException());
+                return Tools.exceptionalFuture(new MessagingException.NoRemoteHandler());
             }
             return handler.apply(ep, message).thenApplyAsync(r -> r, executor);
         }