Allow sharing the same gRPC channel between clients
This change refactors the gRPC protocol subsystem to allow creating a
gRPC channel independently of the client, and to allow multiple clients
to share the same channel (e.g., as in Stratum, where we use 3
clients).
Moreover, we refactor the P4RuntimeClient API to support multiple
P4Runtime-internal device IDs with the same client, whereas before each
client was associated with a single such ID.
Finally, we provide an abstract implementation for gRPC-based driver
behaviors, reducing code duplication in P4Runtime, gNMI and gNOI drivers.
Change-Id: I1a46352bbbef1e0d24042f169ae8ba580202944f
diff --git a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
index 96a1671..3be7706 100644
--- a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
+++ b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcChannelControllerImpl.java
@@ -16,18 +16,16 @@
package org.onosproject.grpc.ctl;
-import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.Striped;
import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
-import io.grpc.Status;
-import io.grpc.StatusRuntimeException;
+import io.grpc.netty.GrpcSslContexts;
+import io.grpc.netty.NettyChannelBuilder;
+import io.netty.handler.ssl.SslContext;
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import org.onlab.util.Tools;
import org.onosproject.cfg.ComponentConfigService;
import org.onosproject.grpc.api.GrpcChannelController;
-import org.onosproject.grpc.api.GrpcChannelId;
-import org.onosproject.grpc.proto.dummy.Dummy;
-import org.onosproject.grpc.proto.dummy.DummyServiceGrpc;
import org.osgi.service.component.ComponentContext;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
@@ -38,16 +36,19 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import javax.net.ssl.SSLException;
+import java.net.URI;
import java.util.Dictionary;
import java.util.Map;
import java.util.Optional;
-import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.Lock;
+import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
+import static com.google.common.base.Strings.isNullOrEmpty;
import static java.lang.String.format;
import static org.onosproject.grpc.ctl.OsgiPropertyConstants.ENABLE_MESSAGE_LOG;
import static org.onosproject.grpc.ctl.OsgiPropertyConstants.ENABLE_MESSAGE_LOG_DEFAULT;
@@ -61,6 +62,12 @@
})
public class GrpcChannelControllerImpl implements GrpcChannelController {
+    // URI schemes accepted when building a channel from a URI:
+    // "grpc" for plaintext, "grpcs" for TLS.
+    private static final String GRPC = "grpc";
+    private static final String GRPCS = "grpcs";
+
+    // Max inbound message size for URI-built channels, in megabytes;
+    // multiplied by MEGABYTES to obtain the byte count.
+    private static final int DEFAULT_MAX_INBOUND_MSG_SIZE = 256; // Megabytes.
+    private static final int MEGABYTES = 1024 * 1024;
+
@Reference(cardinality = ReferenceCardinality.MANDATORY)
protected ComponentConfigService componentConfigService;
@@ -72,8 +79,8 @@
private final Logger log = LoggerFactory.getLogger(getClass());
-    private Map<GrpcChannelId, ManagedChannel> channels;
-    private Map<GrpcChannelId, GrpcLoggingInterceptor> interceptors;
+    // Channels created by this controller and their logging interceptors,
+    // both keyed by the channel URI.
+    private Map<URI, ManagedChannel> channels;
+    private Map<URI, GrpcLoggingInterceptor> interceptors;
+    // Striped locks taken by create() and destroy() for a given channel URI.
private final Striped<Lock> channelLocks = Striped.lock(30);
@@ -109,129 +116,109 @@
}
@Override
-    public ManagedChannel connectChannel(GrpcChannelId channelId,
-                                         ManagedChannelBuilder<?> channelBuilder) {
-        checkNotNull(channelId);
-        checkNotNull(channelBuilder);
-
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
-
-        try {
-            if (channels.containsKey(channelId)) {
-                throw new IllegalArgumentException(format(
-                        "A channel with ID '%s' already exists", channelId));
-            }
-
-            final GrpcLoggingInterceptor interceptor = new GrpcLoggingInterceptor(
-                    channelId, enableMessageLog);
-            channelBuilder.intercept(interceptor);
-
-            ManagedChannel channel = channelBuilder.build();
-            // Forced connection API is still experimental. Use workaround...
-            // channel.getState(true);
-            try {
-                doDummyMessage(channel);
-            } catch (StatusRuntimeException e) {
-                interceptor.close();
-                shutdownNowAndWait(channel, channelId);
-                throw e;
-            }
-            // If here, channel is open.
-            channels.put(channelId, channel);
-            interceptors.put(channelId, interceptor);
-            return channel;
-        } finally {
-            lock.unlock();
-        }
-    }
-
-    private void doDummyMessage(ManagedChannel channel) throws StatusRuntimeException {
-        DummyServiceGrpc.DummyServiceBlockingStub dummyStub = DummyServiceGrpc
-                .newBlockingStub(channel)
-                .withDeadlineAfter(CONNECTION_TIMEOUT_SECONDS, TimeUnit.SECONDS);
-        try {
-            //noinspection ResultOfMethodCallIgnored
-            dummyStub.sayHello(Dummy.DummyMessageThatNoOneWouldReallyUse
-                                       .getDefaultInstance());
-        } catch (StatusRuntimeException e) {
-            if (!e.getStatus().equals(Status.UNIMPLEMENTED)) {
-                // UNIMPLEMENTED means that the server received our message but
-                // doesn't know how to handle it. Hence, channel is open.
-                throw e;
-            }
-        }
+    public ManagedChannel create(URI channelUri) {
+        // Derive a Netty channel builder from the URI (plaintext for "grpc",
+        // TLS for "grpcs") and delegate to the builder-based overload.
+        return create(channelUri, makeChannelBuilder(channelUri));
}
@Override
-    public void disconnectChannel(GrpcChannelId channelId) {
-        checkNotNull(channelId);
+    public ManagedChannel create(URI channelUri, ManagedChannelBuilder<?> channelBuilder) {
+        checkNotNull(channelUri);
+        checkNotNull(channelBuilder);
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
+        // Striped lock for this URI, also taken by destroy(), so the
+        // existence check and the map updates below cannot interleave
+        // with another create()/destroy() on the same URI.
+        channelLocks.get(channelUri).lock();
try {
-            final ManagedChannel channel = channels.remove(channelId);
-            if (channel != null) {
-                shutdownNowAndWait(channel, channelId);
+            if (channels.containsKey(channelUri)) {
+                throw new IllegalArgumentException(format(
+                        "A channel with ID '%s' already exists", channelUri));
}
-            final GrpcLoggingInterceptor interceptor = interceptors.remove(channelId);
+
+            log.info("Creating new gRPC channel {}...", channelUri);
+
+            final GrpcLoggingInterceptor interceptor = new GrpcLoggingInterceptor(
+                    channelUri, enableMessageLog);
+            channelBuilder.intercept(interceptor);
+
+            // Build exactly one channel and track that same instance, so that
+            // get() and destroy() operate on the channel handed to the caller.
+            // Building a second one here would leak it and leave the caller's
+            // channel untracked (never shut down by destroy()).
+            final ManagedChannel channel = channelBuilder.build();
+
+            channels.put(channelUri, channel);
+            interceptors.put(channelUri, interceptor);
+
+            return channel;
+        } finally {
+            channelLocks.get(channelUri).unlock();
+        }
+    }
+
+    /**
+     * Returns a Netty channel builder for the given server URI.
+     * <p>
+     * The URI must use the {@value #GRPC} (plaintext) or {@value #GRPCS}
+     * (TLS) scheme and specify a non-empty host and a port in [1, 65535].
+     * TLS channels trust any server certificate, which is insecure.
+     *
+     * @param channelUri server URI
+     * @return channel builder, never null
+     * @throws NullPointerException     if channelUri is null
+     * @throws IllegalArgumentException if scheme, host or port are invalid
+     * @throws IllegalStateException    if the TLS context cannot be built
+     */
+    private NettyChannelBuilder makeChannelBuilder(URI channelUri) {
+        checkNotNull(channelUri);
+
+        // Constant-first equals: getScheme() is null for scheme-less URIs,
+        // which must be rejected by the argument check, not cause an NPE.
+        final String scheme = channelUri.getScheme();
+        checkArgument(GRPC.equals(scheme) || GRPCS.equals(scheme),
+                      "Server URI scheme must be %s or %s", GRPC, GRPCS);
+        checkArgument(!isNullOrEmpty(channelUri.getHost()),
+                      "Server host address should not be empty");
+        // URI.getPort() returns -1 when no port is given.
+        checkArgument(channelUri.getPort() > 0 && channelUri.getPort() <= 65535,
+                      "Invalid server port");
+
+        final boolean useTls = GRPCS.equals(scheme);
+
+        final NettyChannelBuilder channelBuilder = NettyChannelBuilder
+                .forAddress(channelUri.getHost(),
+                            channelUri.getPort())
+                .maxInboundMessageSize(DEFAULT_MAX_INBOUND_MSG_SIZE * MEGABYTES);
+
+        if (useTls) {
+            try {
+                // Accept any server certificate; this is insecure and
+                // should not be used in production.
+                final SslContext sslContext = GrpcSslContexts.forClient().trustManager(
+                        InsecureTrustManagerFactory.INSTANCE).build();
+                channelBuilder.sslContext(sslContext).useTransportSecurity();
+            } catch (SSLException e) {
+                // Fail with the cause preserved instead of returning null,
+                // which would surface in create() as an uninformative NPE.
+                throw new IllegalStateException(format(
+                        "Failed to build SSL context for %s", channelUri), e);
+            }
+        } else {
+            channelBuilder.usePlaintext();
+        }
+
+        return channelBuilder;
+    }
+
+    @Override
+    public void destroy(URI channelUri) {
+        checkNotNull(channelUri);
+
+        // Same striped lock as create(), so this cannot interleave with a
+        // concurrent create() for the same URI.
+        channelLocks.get(channelUri).lock();
+        try {
+            // Both removals are no-ops when no channel exists for this URI.
+            final ManagedChannel channel = channels.remove(channelUri);
+            if (channel != null) {
+                shutdownNowAndWait(channel, channelUri);
+            }
+            final GrpcLoggingInterceptor interceptor = interceptors.remove(channelUri);
if (interceptor != null) {
interceptor.close();
}
} finally {
-            lock.unlock();
+            channelLocks.get(channelUri).unlock();
}
}
-    private void shutdownNowAndWait(ManagedChannel channel, GrpcChannelId channelId) {
+    // Calls shutdownNow() on the channel and waits up to 5 seconds for it to
+    // terminate. Problems are logged rather than thrown.
+    private void shutdownNowAndWait(ManagedChannel channel, URI channelUri) {
try {
if (!channel.shutdownNow()
.awaitTermination(5, TimeUnit.SECONDS)) {
-                log.error("Channel '{}' didn't terminate, although we " +
-                                  "triggered a shutdown and waited",
-                          channelId);
+                log.error("Channel {} did not terminate properly",
+                          channelUri);
}
} catch (InterruptedException e) {
-            log.warn("Channel {} didn't shutdown in time", channelId);
+            log.warn("Channel {} didn't shutdown in time", channelUri);
+            // Restore the interrupt flag so callers can observe the interruption.
Thread.currentThread().interrupt();
}
}
@Override
-    public Map<GrpcChannelId, ManagedChannel> getChannels() {
-        return ImmutableMap.copyOf(channels);
-    }
-
-    @Override
-    public Optional<ManagedChannel> getChannel(GrpcChannelId channelId) {
-        checkNotNull(channelId);
-
-        Lock lock = channelLocks.get(channelId);
-        lock.lock();
-        try {
-            return Optional.ofNullable(channels.get(channelId));
-        } finally {
-            lock.unlock();
-        }
-    }
-
-    @Override
-    public CompletableFuture<Boolean> probeChannel(GrpcChannelId channelId) {
-        final ManagedChannel channel = channels.get(channelId);
-        if (channel == null) {
-            log.warn("Unable to find any channel with ID {}, cannot send probe",
-                     channelId);
-            return CompletableFuture.completedFuture(false);
-        }
-        return CompletableFuture.supplyAsync(() -> {
-            try {
-                doDummyMessage(channel);
-                return true;
-            } catch (StatusRuntimeException e) {
-                log.debug("Probe for {} failed", channelId);
-                log.debug("", e);
-                return false;
-            }
-        });
+    public Optional<ManagedChannel> get(URI channelUri) {
+        checkNotNull(channelUri);
+        // Unlike create()/destroy(), this read takes no lock; presumably safe
+        // because channels is a ConcurrentHashMap initialized outside this
+        // view (e.g. on activation) -- TODO confirm.
+        return Optional.ofNullable(channels.get(channelUri));
}
}