Added locking to channel operations in Grpc controller
Change-Id: Ic6b6542ee1b1c7d582062fa794711dd0f86776bd
diff --git a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcControllerImpl.java b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcControllerImpl.java
index fb1e7ee..91b9db0 100644
--- a/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcControllerImpl.java
+++ b/protocols/grpc/ctl/src/main/java/org/onosproject/grpc/ctl/GrpcControllerImpl.java
@@ -16,7 +16,9 @@
package org.onosproject.grpc.ctl;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Maps;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
@@ -51,6 +53,10 @@
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
+
+import static com.google.common.base.Preconditions.checkNotNull;
/**
* Default implementation of the GrpcController.
@@ -71,6 +77,7 @@
private Map<GrpcStreamObserverId, GrpcObserverHandler> observers;
private Map<GrpcChannelId, ManagedChannel> channels;
private Map<GrpcChannelId, ManagedChannelBuilder<?>> channelBuilders;
+ private final Map<GrpcChannelId, Lock> channelLocks = Maps.newConcurrentMap();
@Activate
public void activate() {
@@ -109,20 +116,26 @@
@Override
public ManagedChannel connectChannel(GrpcChannelId channelId, ManagedChannelBuilder<?> channelBuilder)
throws IOException {
+ checkNotNull(channelId);
+ checkNotNull(channelBuilder);
- if (enableMessageLog) {
- channelBuilder.intercept(new InternalLogChannelInterceptor(channelId));
+ Lock lock = channelLocks.computeIfAbsent(channelId, k -> new ReentrantLock());
+ lock.lock();
+
+ try {
+ if (enableMessageLog) {
+ channelBuilder.intercept(new InternalLogChannelInterceptor(channelId));
+ }
+ ManagedChannel channel = channelBuilder.build();
+ // Forced connection not yet implemented. Use workaround...
+ // channel.getState(true);
+ doDummyMessage(channel);
+ channelBuilders.put(channelId, channelBuilder);
+ channels.put(channelId, channel);
+ return channel;
+ } finally {
+ lock.unlock();
}
-
- ManagedChannel channel = channelBuilder.build();
-
- // Forced connection not yet implemented. Use workaround...
- // channel.getState(true);
- doDummyMessage(channel);
-
- channelBuilders.put(channelId, channelBuilder);
- channels.put(channelId, channel);
- return channel;
}
private void doDummyMessage(ManagedChannel channel) throws IOException {
@@ -141,45 +154,63 @@
@Override
public boolean isChannelOpen(GrpcChannelId channelId) {
- if (!channels.containsKey(channelId)) {
- log.warn("Can't check if channel open for unknown channel id {}", channelId);
- return false;
- }
+ checkNotNull(channelId);
+
+ Lock lock = channelLocks.computeIfAbsent(channelId, k -> new ReentrantLock());
+ lock.lock();
try {
- doDummyMessage(channels.get(channelId));
- return true;
- } catch (IOException e) {
- return false;
+ if (!channels.containsKey(channelId)) {
+ log.warn("Can't check if channel open for unknown channel id {}", channelId);
+ return false;
+ }
+ try {
+ doDummyMessage(channels.get(channelId));
+ return true;
+ } catch (IOException e) {
+ return false;
+ }
+ } finally {
+ lock.unlock();
}
}
@Override
public void disconnectChannel(GrpcChannelId channelId) {
- if (!channels.containsKey(channelId)) {
- // Nothing to do.
- return;
- }
- ManagedChannel channel = channels.get(channelId);
+ checkNotNull(channelId);
+
+ Lock lock = channelLocks.computeIfAbsent(channelId, k -> new ReentrantLock());
+ lock.lock();
try {
- channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
- } catch (InterruptedException e) {
- log.warn("Channel {} didn't shut down in time.");
- channel.shutdownNow();
- }
+ if (!channels.containsKey(channelId)) {
+ // Nothing to do.
+ return;
+ }
+ ManagedChannel channel = channels.get(channelId);
- channels.remove(channelId);
- channelBuilders.remove(channelId);
+ try {
+ channel.shutdown().awaitTermination(5, TimeUnit.SECONDS);
+ } catch (InterruptedException e) {
+ log.warn("Channel {} didn't shut down in time.");
+ channel.shutdownNow();
+ }
+
+ channels.remove(channelId);
+ channelBuilders.remove(channelId);
+ } finally {
+ lock.unlock();
+ }
}
@Override
public Map<GrpcChannelId, ManagedChannel> getChannels() {
- return channels;
+ return ImmutableMap.copyOf(channels);
}
@Override
public Collection<ManagedChannel> getChannels(final DeviceId deviceId) {
+ checkNotNull(deviceId);
final Set<ManagedChannel> deviceChannels = new HashSet<>();
channels.forEach((k, v) -> {
if (k.deviceId().equals(deviceId)) {
@@ -192,7 +223,16 @@
@Override
public Optional<ManagedChannel> getChannel(GrpcChannelId channelId) {
- return Optional.ofNullable(channels.get(channelId));
+ checkNotNull(channelId);
+
+ Lock lock = channelLocks.computeIfAbsent(channelId, k -> new ReentrantLock());
+ lock.lock();
+
+ try {
+ return Optional.ofNullable(channels.get(channelId));
+ } finally {
+ lock.unlock();
+ }
}
/**