[ONOS-7373] Ensure Netty channels are closed on send exceptions

Change-Id: Ia77ec7857bb5b8dedb508edd8e0977f2eaddac0b
diff --git a/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManager.java b/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManager.java
index 0359bae..20c38e0 100644
--- a/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManager.java
+++ b/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/NettyMessagingManager.java
@@ -49,8 +49,10 @@
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.function.BiConsumer;
 import java.util.function.BiFunction;
+import java.util.function.Consumer;
 import java.util.function.Function;
 
+import com.google.common.base.Throwables;
 import com.google.common.cache.Cache;
 import com.google.common.cache.CacheBuilder;
 import com.google.common.collect.Lists;
@@ -338,57 +340,6 @@
         return Math.abs(messageType.hashCode() % CHANNEL_POOL_SIZE);
     }
 
-    private CompletableFuture<Channel> getChannel(Endpoint endpoint, String messageType) {
-        List<CompletableFuture<Channel>> channelPool = getChannelPool(endpoint);
-        int offset = getChannelOffset(messageType);
-
-        CompletableFuture<Channel> channelFuture = channelPool.get(offset);
-        if (channelFuture == null || channelFuture.isCompletedExceptionally()) {
-            synchronized (channelPool) {
-                channelFuture = channelPool.get(offset);
-                if (channelFuture == null || channelFuture.isCompletedExceptionally()) {
-                    channelFuture = openChannel(endpoint);
-                    channelPool.set(offset, channelFuture);
-                }
-            }
-        }
-
-        CompletableFuture<Channel> future = new CompletableFuture<>();
-        final CompletableFuture<Channel> finalFuture = channelFuture;
-        finalFuture.whenComplete((channel, error) -> {
-            if (error == null) {
-                if (!channel.isActive()) {
-                    synchronized (channelPool) {
-                        CompletableFuture<Channel> currentFuture = channelPool.get(offset);
-                        if (currentFuture == finalFuture) {
-                            channelPool.set(offset, null);
-                            getChannel(endpoint, messageType).whenComplete((recursiveResult, recursiveError) -> {
-                                if (recursiveError == null) {
-                                    future.complete(recursiveResult);
-                                } else {
-                                    future.completeExceptionally(recursiveError);
-                                }
-                            });
-                        } else {
-                            currentFuture.whenComplete((recursiveResult, recursiveError) -> {
-                                if (recursiveError == null) {
-                                    future.complete(recursiveResult);
-                                } else {
-                                    future.completeExceptionally(recursiveError);
-                                }
-                            });
-                        }
-                    }
-                } else {
-                    future.complete(channel);
-                }
-            } else {
-                future.completeExceptionally(error);
-            }
-        });
-        return future;
-    }
-
     private <T> CompletableFuture<T> executeOnPooledConnection(
             Endpoint endpoint,
             String type,
@@ -400,11 +351,13 @@
     }
 
     private <T> void executeOnPooledConnection(
-            Endpoint endpoint,
-            String type,
-            Function<ClientConnection, CompletableFuture<T>> callback,
-            Executor executor,
-            CompletableFuture<T> future) {
+        Endpoint endpoint,
+        String type,
+        Function<ClientConnection, CompletableFuture<T>> callback,
+        Executor executor,
+        CompletableFuture<T> future) {
+
+        // If the endpoint is the local node, avoid the loopback interface and use the singleton local connection.
         if (endpoint.equals(localEndpoint)) {
             callback.apply(localClientConnection).whenComplete((result, error) -> {
                 if (error == null) {
@@ -416,18 +369,74 @@
             return;
         }
 
-        getChannel(endpoint, type).whenComplete((channel, channelError) -> {
-            if (channelError == null) {
-                ClientConnection connection = clientConnections.computeIfAbsent(channel, RemoteClientConnection::new);
-                callback.apply(connection).whenComplete((result, sendError) -> {
-                    if (sendError == null) {
-                        executor.execute(() -> future.complete(result));
-                    } else {
-                        executor.execute(() -> future.completeExceptionally(sendError));
+        // Get the channel pool and the offset for this message type.
+        List<CompletableFuture<Channel>> channelPool = getChannelPool(endpoint);
+        int offset = getChannelOffset(type);
+
+        // If the channel future is completed exceptionally, open a new channel.
+        CompletableFuture<Channel> channelFuture = channelPool.get(offset);
+        if (channelFuture == null || channelFuture.isCompletedExceptionally()) {
+            synchronized (channelPool) {
+                channelFuture = channelPool.get(offset);
+                if (channelFuture == null || channelFuture.isCompletedExceptionally()) {
+                    channelFuture = openChannel(endpoint);
+                    channelPool.set(offset, channelFuture);
+                }
+            }
+        }
+
+        // Create a consumer with which to complete the send operation on a given channel.
+        final Consumer<Channel> runner = channel -> {
+            ClientConnection connection = clientConnections.computeIfAbsent(channel, RemoteClientConnection::new);
+            callback.apply(connection).whenComplete((result, sendError) -> {
+                if (sendError == null) {
+                    executor.execute(() -> future.complete(result));
+                } else {
+                    // If an exception other than a TimeoutException occurred, close the connection and
+                    // remove the channel from the pool.
+                    Throwable cause = Throwables.getRootCause(sendError);
+                    if (!(cause instanceof TimeoutException) && !(cause instanceof MessagingException)) {
+                        synchronized (channelPool) {
+                            channelPool.set(offset, null);
+                        }
+                        channel.close();
+                        clientConnections.remove(channel);
+                        connection.close();
                     }
-                });
+                    executor.execute(() -> future.completeExceptionally(sendError));
+                }
+            });
+        };
+
+        // Wait for the channel future to be completed. Once it's complete, if the channel is active then
+        // attempt to send the message. Otherwise, if the channel is inactive then attempt to open a new channel.
+        final CompletableFuture<Channel> finalFuture = channelFuture;
+        finalFuture.whenComplete((channel, error) -> {
+            if (error == null) {
+                if (!channel.isActive()) {
+                    final CompletableFuture<Channel> currentFuture;
+                    synchronized (channelPool) {
+                        currentFuture = channelPool.get(offset);
+                        if (currentFuture == finalFuture) {
+                            channelPool.set(offset, null);
+                        }
+                    }
+                    if (currentFuture == finalFuture) {
+                        executeOnPooledConnection(endpoint, type, callback, executor);
+                    } else {
+                        currentFuture.whenComplete((recursiveResult, recursiveError) -> {
+                            if (recursiveError == null) {
+                                runner.accept(recursiveResult);
+                            } else {
+                                future.completeExceptionally(recursiveError);
+                            }
+                        });
+                    }
+                } else {
+                    runner.accept(channel);
+                }
             } else {
-                executor.execute(() -> future.completeExceptionally(channelError));
+                future.completeExceptionally(error);
             }
         });
     }
@@ -652,6 +661,20 @@
             context.close();
         }
 
+        @Override
+        public void channelInactive(ChannelHandlerContext context) throws Exception {
+            RemoteClientConnection clientConnection = clientConnections.remove(context.channel());
+            if (clientConnection != null) {
+                clientConnection.close();
+            }
+
+            RemoteServerConnection serverConnection = serverConnections.remove(context.channel());
+            if (serverConnection != null) {
+                serverConnection.close();
+            }
+            context.close();
+        }
+
         /**
          * Returns true if the given message should be handled.
          *