Allow sharing the same gRPC channel between clients

This change refactors the gRPC protocol subsystem so that a gRPC channel
can be created independently of the client, and multiple clients can
share the same channel (e.g., in Stratum, where we use 3 clients).

Moreover, we refactor the P4RuntimeClient API so that the same client can
serve multiple P4Runtime-internal device IDs; previously, each client was
associated with exactly one such ID.

Finally, we provide an abstract implementation for gRPC-based driver
behaviors, reducing code duplication in P4Runtime, gNMI and gNOI drivers.

Change-Id: I1a46352bbbef1e0d24042f169ae8ba580202944f
diff --git a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
index 96a1671..3be7706 100644
--- a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
+++ b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
@@ -16,18 +16,16 @@
 
 package org.onosproject.grpc.ctl;
 
-import com.google.common.collect.ImmutableMap;
 import com.google.common.util.concurrent.Striped;
 import io.grpc.ManagedChannel;
 import io.grpc.ManagedChannelBuilder;
-import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
 import org.onlab.util.Tools;
 import org.onosproject.cfg.ComponentConfigService;
 import org.onosproject.grpc.api.GrpcChannelController;
-import org.onosproject.grpc.api.GrpcChannelId;
-import org.onosproject.grpc.proto.dummy.Dummy;
-import org.onosproject.grpc.proto.dummy.DummyServiceGrpc;
 import org.osgi.service.component.ComponentContext;
 import org.osgi.service.component.annotations.Activate;
 import org.osgi.service.component.annotations.Component;
@@ -38,16 +36,19 @@
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import javax.net.ssl.SSLException;
+import java.net.URI;
 import java.util.Dictionary;
 import java.util.Map;
 import java.util.Optional;
-import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.locks.Lock;
 
+import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Strings.isNullOrEmpty;
 import static java.lang.String.format;
 import static org.onosproject.grpc.ctl.OsgiPropertyConstants.ENABLE_MESSAGE_LOG;
 import static org.onosproject.grpc.ctl.OsgiPropertyConstants.ENABLE_MESSAGE_LOG_DEFAULT;
@@ -61,6 +62,12 @@
         })
 public class GrpcChannelControllerImpl implements GrpcChannelController {
 
+    private static final String GRPC = "grpc"; // URI scheme for plaintext channels.
+    private static final String GRPCS = "grpcs"; // URI scheme for TLS channels.
+
+    private static final int DEFAULT_MAX_INBOUND_MSG_SIZE = 256; // Megabytes (multiplied by MEGABYTES).
+    private static final int MEGABYTES = 1024 * 1024; // Bytes per megabyte.
+
     @Reference(cardinality = ReferenceCardinality.MANDATORY)
     protected ComponentConfigService componentConfigService;
 
@@ -72,8 +79,8 @@
 
     private final Logger log = LoggerFactory.getLogger(getClass());
 
-    private Map<GrpcChannelId, ManagedChannel> channels;
-    private Map<GrpcChannelId, GrpcLoggingInterceptor> interceptors;
+    private Map<URI, ManagedChannel> channels; // Channels created by create(), keyed by their URI.
+    private Map<URI, GrpcLoggingInterceptor> interceptors; // Interceptor installed on each channel, same key.
 
     private final Striped<Lock> channelLocks = Striped.lock(30);
 
@@ -109,129 +116,109 @@
     }
 
     @Override
-    public ManagedChannel connectChannel(GrpcChannelId channelId,
-                                         ManagedChannelBuilder<?> channelBuilder) {
-        checkNotNull(channelId);
-        checkNotNull(channelBuilder);
-
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
-
-        try {
-            if (channels.containsKey(channelId)) {
-                throw new IllegalArgumentException(format(
-                        "A channel with ID '%s' already exists", channelId));
-            }
-
-            final GrpcLoggingInterceptor interceptor = new GrpcLoggingInterceptor(
-                    channelId, enableMessageLog);
-            channelBuilder.intercept(interceptor);
-
-            ManagedChannel channel = channelBuilder.build();
-            // Forced connection API is still experimental. Use workaround...
-            // channel.getState(true);
-            try {
-                doDummyMessage(channel);
-            } catch (StatusRuntimeException e) {
-                interceptor.close();
-                shutdownNowAndWait(channel, channelId);
-                throw e;
-            }
-            // If here, channel is open.
-            channels.put(channelId, channel);
-            interceptors.put(channelId, interceptor);
-            return channel;
-        } finally {
-            lock.unlock();
-        }
-    }
-
-    private void doDummyMessage(ManagedChannel channel) throws StatusRuntimeException {
-        DummyServiceGrpc.DummyServiceBlockingStub dummyStub = DummyServiceGrpc
-                .newBlockingStub(channel)
-                .withDeadlineAfter(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
-        try {
-            //noinspection ResultOfMethodCallIgnored
-            dummyStub.sayHello(Dummy.DummyMessageThatNoOneWouldReallyUse
-                                       .getDefaultInstance());
-        } catch (StatusRuntimeException e) {
-            if (!e.getStatus().equals(Status.UNIMPLEMENTED)) {
-                // UNIMPLEMENTED means that the server received our message but
-                // doesn't know how to handle it. Hence, channel is open.
-                throw e;
-            }
-        }
+    public ManagedChannel create(URI channelUri) {
+        return create(channelUri, makeChannelBuilder(channelUri)); // Builder (TLS or plaintext) derived from the URI.
     }
 
     @Override
-    public void disconnectChannel(GrpcChannelId channelId) {
-        checkNotNull(channelId);
+    public ManagedChannel create(URI channelUri, ManagedChannelBuilder<?> channelBuilder) {
+        checkNotNull(channelUri);
+        checkNotNull(channelBuilder);
 
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
+        channelLocks.get(channelUri).lock();
         try {
-            final ManagedChannel channel = channels.remove(channelId);
-            if (channel != null) {
-                shutdownNowAndWait(channel, channelId);
+            if (channels.containsKey(channelUri)) {
+                throw new IllegalArgumentException(format(
+                        "A channel with ID '%s' already exists", channelUri));
             }
-            final GrpcLoggingInterceptor interceptor = interceptors.remove(channelId);
+
+            log.info("Creating new gRPC channel {}...", channelUri);
+            // Install the logging interceptor before building; destroy() closes it.
+            final GrpcLoggingInterceptor interceptor = new GrpcLoggingInterceptor(
+                    channelUri, enableMessageLog);
+            channelBuilder.intercept(interceptor);
+
+            final ManagedChannel channel = channelBuilder.build();
+            // Track the very instance returned to the caller, so destroy() shuts it down.
+            channels.put(channelUri, channel);
+            interceptors.put(channelUri, interceptor);
+
+            return channel;
+        } finally {
+            channelLocks.get(channelUri).unlock();
+        }
+    }
+
+    /**
+     * Returns a Netty builder for a {@code grpc} (plaintext) or {@code grpcs} (TLS) URI.
+     *
+     * @throws IllegalArgumentException if the URI scheme, host or port is invalid
+     * @throws IllegalStateException if the TLS context cannot be built
+     */
+    private NettyChannelBuilder makeChannelBuilder(URI channelUri) {
+        final String scheme = channelUri.getScheme(); // Null when the URI has no scheme.
+        checkArgument(GRPC.equals(scheme) || GRPCS.equals(scheme),
+                      "Server URI scheme must be %s or %s", GRPC, GRPCS);
+        checkArgument(!isNullOrEmpty(channelUri.getHost()),
+                      "Server host address should not be empty");
+        checkArgument(channelUri.getPort() > 0 && channelUri.getPort() <= 65535,
+                      "Invalid server port");
+
+        final NettyChannelBuilder channelBuilder = NettyChannelBuilder
+                .forAddress(channelUri.getHost(), channelUri.getPort())
+                .maxInboundMessageSize(DEFAULT_MAX_INBOUND_MSG_SIZE * MEGABYTES);
+
+        if (GRPCS.equals(scheme)) {
+            try {
+                // Accept any server certificate; this is insecure and
+                // should not be used in production.
+                final SslContext sslContext = GrpcSslContexts.forClient().trustManager(
+                        InsecureTrustManagerFactory.INSTANCE).build();
+                channelBuilder.sslContext(sslContext).useTransportSecurity();
+            } catch (SSLException e) {
+                throw new IllegalStateException("Failed to build SSL context for " + channelUri, e);
+            }
+        } else {
+            channelBuilder.usePlaintext();
+        }
+        return channelBuilder;
+    }
+
+    @Override
+    public void destroy(URI channelUri) {
+        final Lock lock = channelLocks.get(checkNotNull(channelUri));
 
+        lock.lock(); // Same striped lock create() takes for this URI.
+        try {
+            final ManagedChannel removed = channels.remove(channelUri);
+            if (removed != null) {
+                shutdownNowAndWait(removed, channelUri);
+            }
+            final GrpcLoggingInterceptor interceptor = interceptors.remove(channelUri);
             if (interceptor != null) {
                 interceptor.close();
             }
         } finally {
-            lock.unlock();
+            lock.unlock();
         }
     }
 
-    private void shutdownNowAndWait(ManagedChannel channel, GrpcChannelId channelId) {
+    private void shutdownNowAndWait(ManagedChannel channel, URI channelUri) { // Waits at most 5 seconds.
         try {
             if (!channel.shutdownNow()
                     .awaitTermination(5, TimeUnit.SECONDS)) {
-                log.error("Channel '{}' didn't terminate, although we " +
-                                  "triggered a shutdown and waited",
-                          channelId);
+                log.error("Channel {} did not terminate properly",
+                          channelUri);
             }
         } catch (InterruptedException e) {
-            log.warn("Channel {} didn't shutdown in time", channelId);
+            log.warn("Interrupted while waiting for channel {} to terminate", channelUri);
             Thread.currentThread().interrupt();
         }
     }
 
     @Override
-    public Map<GrpcChannelId, ManagedChannel> getChannels() {
-        return ImmutableMap.copyOf(channels);
-    }
-
-    @Override
-    public Optional<ManagedChannel> getChannel(GrpcChannelId channelId) {
-        checkNotNull(channelId);
-
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
-        try {
-            return Optional.ofNullable(channels.get(channelId));
-        } finally {
-            lock.unlock();
-        }
-    }
-
-    @Override
-    public CompletableFuture<Boolean> probeChannel(GrpcChannelId channelId) {
-        final ManagedChannel channel = channels.get(channelId);
-        if (channel == null) {
-            log.warn("Unable to find any channel with ID {}, cannot send probe",
-                     channelId);
-            return CompletableFuture.completedFuture(false);
-        }
-        return CompletableFuture.supplyAsync(() -> {
-            try {
-                doDummyMessage(channel);
-                return true;
-            } catch (StatusRuntimeException e) {
-                log.debug("Probe for {} failed", channelId);
-                log.debug("", e);
-                return false;
-            }
-        });
+    public Optional<ManagedChannel> get(URI channelUri) {
+        final ManagedChannel channel = channels.get(checkNotNull(channelUri)); // Read without channelLocks.
+        return Optional.ofNullable(channel);
     }
 }