MessagingService API enhancements

Change-Id: Iabfe15d4f08d7c53bd6575c5d94d0ac9f4e1a38e
diff --git a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
index 03f7276..dd82804 100644
--- a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
+++ b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
@@ -15,6 +15,12 @@
  */
 package org.onlab.netty;
 
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.RemovalListener;
+import com.google.common.cache.RemovalNotification;
+import com.google.common.util.concurrent.MoreExecutors;
+
 import io.netty.bootstrap.Bootstrap;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
@@ -35,6 +41,19 @@
 import io.netty.channel.socket.nio.NioServerSocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
 
+import org.apache.commons.pool.KeyedPoolableObjectFactory;
+import org.apache.commons.pool.impl.GenericKeyedObjectPool;
+import org.onlab.util.Tools;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManagerFactory;
+
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.security.KeyStore;
@@ -47,25 +66,9 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
-import java.util.function.Function;
-
-import javax.net.ssl.KeyManagerFactory;
-import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLEngine;
-import javax.net.ssl.TrustManagerFactory;
-
-import org.apache.commons.pool.KeyedPoolableObjectFactory;
-import org.apache.commons.pool.impl.GenericKeyedObjectPool;
-import org.onosproject.store.cluster.messaging.Endpoint;
-import org.onosproject.store.cluster.messaging.MessagingService;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.google.common.cache.Cache;
-import com.google.common.cache.CacheBuilder;
-import com.google.common.cache.RemovalListener;
-import com.google.common.cache.RemovalNotification;
 
 /**
  * Implementation of MessagingService based on <a href="http://netty.io/">Netty</a> framework.
@@ -81,11 +84,11 @@
     private final AtomicBoolean started = new AtomicBoolean(false);
     private final Map<String, Consumer<InternalMessage>> handlers = new ConcurrentHashMap<>();
     private final AtomicLong messageIdGenerator = new AtomicLong(0);
-    private final Cache<Long, CompletableFuture<byte[]>> responseFutures = CacheBuilder.newBuilder()
+    private final Cache<Long, Callback> callbacks = CacheBuilder.newBuilder()
             .expireAfterWrite(10, TimeUnit.SECONDS)
-            .removalListener(new RemovalListener<Long, CompletableFuture<byte[]>>() {
+            .removalListener(new RemovalListener<Long, Callback>() {
                 @Override
-                public void onRemoval(RemovalNotification<Long, CompletableFuture<byte[]>> entry) {
+                public void onRemoval(RemovalNotification<Long, Callback> entry) {
                     if (entry.wasEvicted()) {
                         entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply"));
                     }
@@ -165,25 +168,29 @@
     }
 
     protected CompletableFuture<Void> sendAsync(Endpoint ep, InternalMessage message) {
+        if (ep.equals(localEp)) {
+            try {
+                dispatchLocally(message);
+            } catch (IOException e) {
+                return Tools.exceptionalFuture(e);
+            }
+            return CompletableFuture.completedFuture(null);
+        }
+
         CompletableFuture<Void> future = new CompletableFuture<>();
         try {
-            if (ep.equals(localEp)) {
-                dispatchLocally(message);
-                future.complete(null);
-            } else {
-                Channel channel = null;
-                try {
-                    channel = channels.borrowObject(ep);
-                    channel.writeAndFlush(message).addListener(channelFuture -> {
-                        if (!channelFuture.isSuccess()) {
-                            future.completeExceptionally(channelFuture.cause());
-                        } else {
-                            future.complete(null);
-                        }
-                    });
-                } finally {
-                    channels.returnObject(ep, channel);
-                }
+            Channel channel = null;
+            try {
+                channel = channels.borrowObject(ep);
+                channel.writeAndFlush(message).addListener(channelFuture -> {
+                    if (!channelFuture.isSuccess()) {
+                        future.completeExceptionally(channelFuture.cause());
+                    } else {
+                        future.complete(null);
+                    }
+                });
+            } finally {
+                channels.returnObject(ep, channel);
             }
         } catch (Exception e) {
             future.completeExceptionally(e);
@@ -193,28 +200,32 @@
 
     @Override
     public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
+        return sendAndReceive(ep, type, payload, MoreExecutors.directExecutor());
+    }
+
+    @Override
+    public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
         CompletableFuture<byte[]> response = new CompletableFuture<>();
+        Callback callback = new Callback(response, executor);
         Long messageId = messageIdGenerator.incrementAndGet();
-        responseFutures.put(messageId, response);
+        callbacks.put(messageId, callback);
         InternalMessage message = new InternalMessage(messageId, localEp, type, payload);
-        try {
-            sendAsync(ep, message);
-        } catch (Exception e) {
-            responseFutures.invalidate(messageId);
-            response.completeExceptionally(e);
-        }
-        return response;
+        return sendAsync(ep, message).whenComplete((r, e) -> {
+            if (e != null) {
+                callbacks.invalidate(messageId);
+            }
+        }).thenCompose(v -> response);
     }
 
     @Override
-    public void registerHandler(String type, Consumer<byte[]> handler, Executor executor) {
-        handlers.put(type, message -> executor.execute(() -> handler.accept(message.payload())));
+    public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
+        handlers.put(type, message -> executor.execute(() -> handler.accept(message.sender(), message.payload())));
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor) {
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
         handlers.put(type, message -> executor.execute(() -> {
-            byte[] responsePayload = handler.apply(message.payload());
+            byte[] responsePayload = handler.apply(message.sender(), message.payload());
             if (responsePayload != null) {
                 InternalMessage response = new InternalMessage(message.id(),
                         localEp,
@@ -230,9 +241,9 @@
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler) {
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
         handlers.put(type, message -> {
-            handler.apply(message.payload()).whenComplete((result, error) -> {
+            handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
                 if (error == null) {
                     InternalMessage response = new InternalMessage(message.id(),
                                                                    localEp,
@@ -435,17 +446,17 @@
         String type = message.type();
         if (REPLY_MESSAGE_TYPE.equals(type)) {
             try {
-                CompletableFuture<byte[]> futureResponse =
-                    responseFutures.getIfPresent(message.id());
-                if (futureResponse != null) {
-                    futureResponse.complete(message.payload());
+                Callback callback =
+                    callbacks.getIfPresent(message.id());
+                if (callback != null) {
+                    callback.complete(message.payload());
                 } else {
                     log.warn("Received a reply for message id:[{}]. "
                             + " from {}. But was unable to locate the"
                             + " request handle", message.id(), message.sender());
                 }
             } finally {
-                responseFutures.invalidate(message.id());
+                callbacks.invalidate(message.id());
             }
             return;
         }
@@ -456,4 +467,48 @@
             log.debug("No handler registered for {}", type);
         }
     }
+
+    /**
+     * Pending reply for a request issued via sendAndReceive: the future handed back to the
+     * caller plus the executor on which that future is completed.
+     * Static because it never touches the enclosing NettyMessaging instance.
+     */
+    private static final class Callback {
+        private final CompletableFuture<byte[]> future;
+        private final Executor executor;
+
+        public Callback(CompletableFuture<byte[]> future, Executor executor) {
+            this.future = future;
+            this.executor = executor;
+        }
+
+        /**
+         * Completes the future with the reply payload on the callback's executor.
+         *
+         * @param value reply payload
+         */
+        public void complete(byte[] value) {
+            try {
+                executor.execute(() -> future.complete(value));
+            } catch (java.util.concurrent.RejectedExecutionException e) {
+                // Executor cannot take the task (e.g. it was shut down); fail the future
+                // rather than leave the caller waiting on it forever.
+                future.completeExceptionally(e);
+            }
+        }
+
+        /**
+         * Fails the future with the given error on the callback's executor.
+         *
+         * @param error cause reported to the caller
+         */
+        public void completeExceptionally(Throwable error) {
+            try {
+                executor.execute(() -> future.completeExceptionally(error));
+            } catch (java.util.concurrent.RejectedExecutionException e) {
+                // Executor cannot take the task; deliver the original error directly.
+                future.completeExceptionally(error);
+            }
+        }
+    }
 }