MessagingService API enchancements

Change-Id: Iabfe15d4f08d7c53bd6575c5d94d0ac9f4e1a38e
diff --git a/core/api/src/main/java/org/onosproject/store/cluster/messaging/MessagingService.java b/core/api/src/main/java/org/onosproject/store/cluster/messaging/MessagingService.java
index 6ccd483..73f4258 100644
--- a/core/api/src/main/java/org/onosproject/store/cluster/messaging/MessagingService.java
+++ b/core/api/src/main/java/org/onosproject/store/cluster/messaging/MessagingService.java
@@ -17,8 +17,8 @@
 
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
-import java.util.function.Consumer;
-import java.util.function.Function;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 
 /**
  * Interface for low level messaging primitives.
@@ -36,7 +36,7 @@
     CompletableFuture<Void> sendAsync(Endpoint ep, String type, byte[] payload);
 
     /**
-     * Sends a message synchronously and waits for a response.
+     * Sends a message asynchronously and expects a response.
      * @param ep end point to send the message to.
      * @param type type of message.
      * @param payload message payload.
@@ -45,12 +45,14 @@
     CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload);
 
     /**
-     * Registers a new message handler for message type.
-     * @param type message type.
-     * @param handler message handler
-     * @param executor executor to use for running message handler logic.
+     * Sends a message synchronously and expects a response.
+     * @param ep end point to send the message to.
+     * @param type type of message.
+     * @param payload message payload.
+     * @param executor executor over which any follow up actions after completion will be executed.
+     * @return a response future
      */
-    void registerHandler(String type, Consumer<byte[]> handler, Executor executor);
+    CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor);
 
     /**
      * Registers a new message handler for message type.
@@ -58,14 +60,22 @@
      * @param handler message handler
      * @param executor executor to use for running message handler logic.
      */
-    void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor);
+    void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor);
+
+    /**
+     * Registers a new message handler for message type.
+     * @param type message type.
+     * @param handler message handler
+     * @param executor executor to use for running message handler logic.
+     */
+    void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor);
 
     /**
      * Registers a new message handler for message type.
      * @param type message type.
      * @param handler message handler
      */
-    void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler);
+    void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler);
 
     /**
      * Unregister current handler, if one exists for message type.
diff --git a/core/store/dist/src/main/java/org/onosproject/store/cluster/impl/DistributedClusterStore.java b/core/store/dist/src/main/java/org/onosproject/store/cluster/impl/DistributedClusterStore.java
index b2ee832..b537517 100644
--- a/core/store/dist/src/main/java/org/onosproject/store/cluster/impl/DistributedClusterStore.java
+++ b/core/store/dist/src/main/java/org/onosproject/store/cluster/impl/DistributedClusterStore.java
@@ -49,7 +49,7 @@
 import java.util.concurrent.Executors;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
-import java.util.function.Consumer;
+import java.util.function.BiConsumer;
 import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkNotNull;
@@ -241,9 +241,9 @@
         });
     }
 
-    private class HeartbeatMessageHandler implements Consumer<byte[]> {
+    private class HeartbeatMessageHandler implements BiConsumer<Endpoint, byte[]> {
         @Override
-        public void accept(byte[] message) {
+        public void accept(Endpoint sender, byte[] message) {
             HeartbeatMessage hb = SERIALIZER.decode(message);
             failureDetector.report(hb.source().id());
             hb.knownPeers().forEach(node -> {
diff --git a/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/ClusterCommunicationManager.java b/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/ClusterCommunicationManager.java
index 8a237ef..df4ac5c 100644
--- a/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/ClusterCommunicationManager.java
+++ b/core/store/dist/src/main/java/org/onosproject/store/cluster/messaging/impl/ClusterCommunicationManager.java
@@ -35,10 +35,13 @@
 import org.slf4j.LoggerFactory;
 
 import com.google.common.base.Objects;
+
 import java.util.Set;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.Executor;
 import java.util.concurrent.ExecutorService;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
 import java.util.function.Function;
 import java.util.stream.Collectors;
@@ -210,7 +213,7 @@
                 executor);
     }
 
-    private class InternalClusterMessageHandler implements Function<byte[], byte[]> {
+    private class InternalClusterMessageHandler implements BiFunction<Endpoint, byte[], byte[]> {
         private ClusterMessageHandler handler;
 
         public InternalClusterMessageHandler(ClusterMessageHandler handler) {
@@ -218,14 +221,14 @@
         }
 
         @Override
-        public byte[] apply(byte[] bytes) {
+        public byte[] apply(Endpoint sender, byte[] bytes) {
             ClusterMessage message = ClusterMessage.fromBytes(bytes);
             handler.handle(message);
             return message.response();
         }
     }
 
-    private class InternalMessageResponder<M, R> implements Function<byte[], CompletableFuture<byte[]>> {
+    private class InternalMessageResponder<M, R> implements BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> {
         private final Function<byte[], M> decoder;
         private final Function<R, byte[]> encoder;
         private final Function<M, CompletableFuture<R>> handler;
@@ -239,12 +242,12 @@
         }
 
         @Override
-        public CompletableFuture<byte[]> apply(byte[] bytes) {
+        public CompletableFuture<byte[]> apply(Endpoint sender, byte[] bytes) {
             return handler.apply(decoder.apply(ClusterMessage.fromBytes(bytes).payload())).thenApply(encoder);
         }
     }
 
-    private class InternalMessageConsumer<M> implements Consumer<byte[]> {
+    private class InternalMessageConsumer<M> implements BiConsumer<Endpoint, byte[]> {
         private final Function<byte[], M> decoder;
         private final Consumer<M> consumer;
 
@@ -254,7 +257,7 @@
         }
 
         @Override
-        public void accept(byte[] bytes) {
+        public void accept(Endpoint sender, byte[] bytes) {
             consumer.accept(decoder.apply(ClusterMessage.fromBytes(bytes).payload()));
         }
     }
diff --git a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
index 03f7276..dd82804 100644
--- a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
+++ b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
@@ -15,6 +15,12 @@
  */
 package org.onlab.netty;
 
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.RemovalListener;
+import com.google.common.cache.RemovalNotification;
+import com.google.common.util.concurrent.MoreExecutors;
+
 import io.netty.bootstrap.Bootstrap;
 import io.netty.bootstrap.ServerBootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
@@ -35,6 +41,19 @@
 import io.netty.channel.socket.nio.NioServerSocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
 
+import org.apache.commons.pool.KeyedPoolableObjectFactory;
+import org.apache.commons.pool.impl.GenericKeyedObjectPool;
+import org.onlab.util.Tools;
+import org.onosproject.store.cluster.messaging.Endpoint;
+import org.onosproject.store.cluster.messaging.MessagingService;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManagerFactory;
+
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.security.KeyStore;
@@ -47,25 +66,9 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
-import java.util.function.Function;
-
-import javax.net.ssl.KeyManagerFactory;
-import javax.net.ssl.SSLContext;
-import javax.net.ssl.SSLEngine;
-import javax.net.ssl.TrustManagerFactory;
-
-import org.apache.commons.pool.KeyedPoolableObjectFactory;
-import org.apache.commons.pool.impl.GenericKeyedObjectPool;
-import org.onosproject.store.cluster.messaging.Endpoint;
-import org.onosproject.store.cluster.messaging.MessagingService;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import com.google.common.cache.Cache;
-import com.google.common.cache.CacheBuilder;
-import com.google.common.cache.RemovalListener;
-import com.google.common.cache.RemovalNotification;
 
 /**
  * Implementation of MessagingService based on <a href="http://netty.io/">Netty</a> framework.
@@ -81,11 +84,11 @@
     private final AtomicBoolean started = new AtomicBoolean(false);
     private final Map<String, Consumer<InternalMessage>> handlers = new ConcurrentHashMap<>();
     private final AtomicLong messageIdGenerator = new AtomicLong(0);
-    private final Cache<Long, CompletableFuture<byte[]>> responseFutures = CacheBuilder.newBuilder()
+    private final Cache<Long, Callback> callbacks = CacheBuilder.newBuilder()
             .expireAfterWrite(10, TimeUnit.SECONDS)
-            .removalListener(new RemovalListener<Long, CompletableFuture<byte[]>>() {
+            .removalListener(new RemovalListener<Long, Callback>() {
                 @Override
-                public void onRemoval(RemovalNotification<Long, CompletableFuture<byte[]>> entry) {
+                public void onRemoval(RemovalNotification<Long, Callback> entry) {
                     if (entry.wasEvicted()) {
                         entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply"));
                     }
@@ -165,25 +168,29 @@
     }
 
     protected CompletableFuture<Void> sendAsync(Endpoint ep, InternalMessage message) {
+        if (ep.equals(localEp)) {
+            try {
+                dispatchLocally(message);
+            } catch (IOException e) {
+                return Tools.exceptionalFuture(e);
+            }
+            return CompletableFuture.completedFuture(null);
+        }
+
         CompletableFuture<Void> future = new CompletableFuture<>();
         try {
-            if (ep.equals(localEp)) {
-                dispatchLocally(message);
-                future.complete(null);
-            } else {
-                Channel channel = null;
-                try {
-                    channel = channels.borrowObject(ep);
-                    channel.writeAndFlush(message).addListener(channelFuture -> {
-                        if (!channelFuture.isSuccess()) {
-                            future.completeExceptionally(channelFuture.cause());
-                        } else {
-                            future.complete(null);
-                        }
-                    });
-                } finally {
-                    channels.returnObject(ep, channel);
-                }
+            Channel channel = null;
+            try {
+                channel = channels.borrowObject(ep);
+                channel.writeAndFlush(message).addListener(channelFuture -> {
+                    if (!channelFuture.isSuccess()) {
+                        future.completeExceptionally(channelFuture.cause());
+                    } else {
+                        future.complete(null);
+                    }
+                });
+            } finally {
+                channels.returnObject(ep, channel);
             }
         } catch (Exception e) {
             future.completeExceptionally(e);
@@ -193,28 +200,32 @@
 
     @Override
     public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload) {
+        return sendAndReceive(ep, type, payload, MoreExecutors.directExecutor());
+    }
+
+    @Override
+    public CompletableFuture<byte[]> sendAndReceive(Endpoint ep, String type, byte[] payload, Executor executor) {
         CompletableFuture<byte[]> response = new CompletableFuture<>();
+        Callback callback = new Callback(response, executor);
         Long messageId = messageIdGenerator.incrementAndGet();
-        responseFutures.put(messageId, response);
+        callbacks.put(messageId, callback);
         InternalMessage message = new InternalMessage(messageId, localEp, type, payload);
-        try {
-            sendAsync(ep, message);
-        } catch (Exception e) {
-            responseFutures.invalidate(messageId);
-            response.completeExceptionally(e);
-        }
-        return response;
+        return sendAsync(ep, message).whenComplete((r, e) -> {
+            if (e != null) {
+                callbacks.invalidate(messageId);
+            }
+        }).thenCompose(v -> response);
     }
 
     @Override
-    public void registerHandler(String type, Consumer<byte[]> handler, Executor executor) {
-        handlers.put(type, message -> executor.execute(() -> handler.accept(message.payload())));
+    public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
+        handlers.put(type, message -> executor.execute(() -> handler.accept(message.sender(), message.payload())));
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor) {
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
         handlers.put(type, message -> executor.execute(() -> {
-            byte[] responsePayload = handler.apply(message.payload());
+            byte[] responsePayload = handler.apply(message.sender(), message.payload());
             if (responsePayload != null) {
                 InternalMessage response = new InternalMessage(message.id(),
                         localEp,
@@ -230,9 +241,9 @@
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler) {
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
         handlers.put(type, message -> {
-            handler.apply(message.payload()).whenComplete((result, error) -> {
+            handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
                 if (error == null) {
                     InternalMessage response = new InternalMessage(message.id(),
                                                                    localEp,
@@ -435,17 +446,17 @@
         String type = message.type();
         if (REPLY_MESSAGE_TYPE.equals(type)) {
             try {
-                CompletableFuture<byte[]> futureResponse =
-                    responseFutures.getIfPresent(message.id());
-                if (futureResponse != null) {
-                    futureResponse.complete(message.payload());
+                Callback callback =
+                    callbacks.getIfPresent(message.id());
+                if (callback != null) {
+                    callback.complete(message.payload());
                 } else {
                     log.warn("Received a reply for message id:[{}]. "
                             + " from {}. But was unable to locate the"
                             + " request handle", message.id(), message.sender());
                 }
             } finally {
-                responseFutures.invalidate(message.id());
+                callbacks.invalidate(message.id());
             }
             return;
         }
@@ -456,4 +467,22 @@
             log.debug("No handler registered for {}", type);
         }
     }
+
+    private final class Callback {
+        private final CompletableFuture<byte[]> future;
+        private final Executor executor;
+
+        public Callback(CompletableFuture<byte[]> future, Executor executor) {
+            this.future = future;
+            this.executor = executor;
+        }
+
+        public void complete(byte[] value) {
+            executor.execute(() -> future.complete(value));
+        }
+
+        public void completeExceptionally(Throwable error) {
+            executor.execute(() -> future.completeExceptionally(error));
+        }
+    }
 }
diff --git a/utils/nio/src/main/java/org/onlab/nio/service/IOLoopMessaging.java b/utils/nio/src/main/java/org/onlab/nio/service/IOLoopMessaging.java
index c195d16..45f8e13 100644
--- a/utils/nio/src/main/java/org/onlab/nio/service/IOLoopMessaging.java
+++ b/utils/nio/src/main/java/org/onlab/nio/service/IOLoopMessaging.java
@@ -33,8 +33,9 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
+import java.util.function.BiConsumer;
+import java.util.function.BiFunction;
 import java.util.function.Consumer;
-import java.util.function.Function;
 
 import org.apache.commons.pool.KeyedPoolableObjectFactory;
 import org.apache.commons.pool.impl.GenericKeyedObjectPool;
@@ -50,6 +51,7 @@
 import com.google.common.cache.RemovalListener;
 import com.google.common.cache.RemovalNotification;
 import com.google.common.collect.Lists;
+import com.google.common.util.concurrent.MoreExecutors;
 
 /**
  * MessagingService implementation based on IOLoop.
@@ -86,10 +88,10 @@
 
     private final ConcurrentMap<String, Consumer<DefaultMessage>> handlers = new ConcurrentHashMap<>();
     private final AtomicLong messageIdGenerator = new AtomicLong(0);
-    private final Cache<Long, CompletableFuture<byte[]>> responseFutures = CacheBuilder.newBuilder()
-            .removalListener(new RemovalListener<Long, CompletableFuture<byte[]>>() {
+    private final Cache<Long, Callback> responseFutures = CacheBuilder.newBuilder()
+            .removalListener(new RemovalListener<Long, Callback>() {
                 @Override
-                public void onRemoval(RemovalNotification<Long, CompletableFuture<byte[]>> entry) {
+                public void onRemoval(RemovalNotification<Long, Callback> entry) {
                     if (entry.wasEvicted()) {
                         entry.getValue().completeExceptionally(new TimeoutException("Timedout waiting for reply"));
                     }
@@ -176,29 +178,37 @@
     public CompletableFuture<byte[]> sendAndReceive(
             Endpoint ep,
             String type,
-            byte[] payload) {
+            byte[] payload,
+            Executor executor) {
         CompletableFuture<byte[]> response = new CompletableFuture<>();
+        Callback callback = new Callback(response, executor);
         Long messageId = messageIdGenerator.incrementAndGet();
-        responseFutures.put(messageId, response);
+        responseFutures.put(messageId, callback);
         DefaultMessage message = new DefaultMessage(messageId, localEp, type, payload);
-        try {
-            sendAsync(ep, message);
-        } catch (Exception e) {
-            responseFutures.invalidate(messageId);
-            response.completeExceptionally(e);
-        }
-        return response;
+        return sendAsync(ep, message).whenComplete((r, e) -> {
+            if (e != null) {
+                responseFutures.invalidate(messageId);
+            }
+        }).thenCompose(v -> response);
     }
 
     @Override
-    public void registerHandler(String type, Consumer<byte[]> handler, Executor executor) {
-        handlers.put(type, message -> executor.execute(() -> handler.accept(message.payload())));
+    public CompletableFuture<byte[]> sendAndReceive(
+            Endpoint ep,
+            String type,
+            byte[] payload) {
+        return sendAndReceive(ep, type, payload, MoreExecutors.directExecutor());
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], byte[]> handler, Executor executor) {
+    public void registerHandler(String type, BiConsumer<Endpoint, byte[]> handler, Executor executor) {
+        handlers.put(type, message -> executor.execute(() -> handler.accept(message.sender(), message.payload())));
+    }
+
+    @Override
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], byte[]> handler, Executor executor) {
         handlers.put(type, message -> executor.execute(() -> {
-            byte[] responsePayload = handler.apply(message.payload());
+            byte[] responsePayload = handler.apply(message.sender(), message.payload());
             if (responsePayload != null) {
                 DefaultMessage response = new DefaultMessage(message.id(),
                         localEp,
@@ -212,9 +222,9 @@
     }
 
     @Override
-    public void registerHandler(String type, Function<byte[], CompletableFuture<byte[]>> handler) {
+    public void registerHandler(String type, BiFunction<Endpoint, byte[], CompletableFuture<byte[]>> handler) {
         handlers.put(type, message -> {
-            handler.apply(message.payload()).whenComplete((result, error) -> {
+            handler.apply(message.sender(), message.payload()).whenComplete((result, error) -> {
                 if (error == null) {
                     DefaultMessage response = new DefaultMessage(message.id(),
                         localEp,
@@ -239,10 +249,10 @@
         String type = message.type();
         if (REPLY_MESSAGE_TYPE.equals(type)) {
             try {
-                CompletableFuture<byte[]> futureResponse =
+                Callback callback =
                         responseFutures.getIfPresent(message.id());
-                if (futureResponse != null) {
-                    futureResponse.complete(message.payload());
+                if (callback != null) {
+                    callback.complete(message.payload());
                 } else {
                     log.warn("Received a reply for message id:[{}]. "
                             + " from {}. But was unable to locate the"
@@ -331,4 +341,23 @@
             return stream.isClosed();
         }
     }
+
+
+    private final class Callback {
+        private final CompletableFuture<byte[]> future;
+        private final Executor executor;
+
+        public Callback(CompletableFuture<byte[]> future, Executor executor) {
+            this.future = future;
+            this.executor = executor;
+        }
+
+        public void complete(byte[] value) {
+            executor.execute(() -> future.complete(value));
+        }
+
+        public void completeExceptionally(Throwable error) {
+            executor.execute(() -> future.completeExceptionally(error));
+        }
+    }
 }