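Allow re-creating P4Runtime clients for the same server address/port

createClient() now takes the gRPC server address and port instead of a
pre-built ManagedChannelBuilder, and builds a plaintext Netty channel
internally. Each client is keyed by (deviceId, serverAddr, serverPort,
p4DeviceId). Calling createClient() again for a device with identical
parameters is a no-op that returns true. Calling it for a device that
already has a client with different parameters throws
IllegalStateException.
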
Change-Id: Ib3de10d047f52dd28511e71385773d4b4a9ad74f
diff --git a/protocols/p4runtime/ctl/src/main/java/org/onosproject/p4runtime/ctl/P4RuntimeControllerImpl.java b/protocols/p4runtime/ctl/src/main/java/org/onosproject/p4runtime/ctl/P4RuntimeControllerImpl.java
index 47ebe68..987356b 100644
--- a/protocols/p4runtime/ctl/src/main/java/org/onosproject/p4runtime/ctl/P4RuntimeControllerImpl.java
+++ b/protocols/p4runtime/ctl/src/main/java/org/onosproject/p4runtime/ctl/P4RuntimeControllerImpl.java
@@ -24,6 +24,7 @@
 import io.grpc.ManagedChannelBuilder;
 import io.grpc.NameResolverProvider;
 import io.grpc.internal.DnsNameResolverProvider;
+import io.grpc.netty.NettyChannelBuilder;
 import org.apache.felix.scr.annotations.Activate;
 import org.apache.felix.scr.annotations.Component;
 import org.apache.felix.scr.annotations.Deactivate;
@@ -53,7 +54,6 @@
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 
 import static com.google.common.base.Preconditions.checkNotNull;
-import static java.lang.String.format;
 import static org.slf4j.LoggerFactory.getLogger;
 
 /**
@@ -69,7 +69,8 @@
     private static final int DEVICE_LOCK_EXPIRE_TIME_IN_MIN = 10;
     private final Logger log = getLogger(getClass());
     private final NameResolverProvider nameResolverProvider = new DnsNameResolverProvider();
-    private final Map<DeviceId, P4RuntimeClient> clients = Maps.newHashMap();
+    private final Map<DeviceId, ClientKey> deviceIdToClientKey = Maps.newHashMap();
+    private final Map<ClientKey, P4RuntimeClient> clientKeyToClient = Maps.newHashMap();
     private final Map<DeviceId, GrpcChannelId> channelIds = Maps.newHashMap();
     private final Map<DeviceId, List<ChannelListener>> channelListeners = Maps.newConcurrentMap();
     private final LoadingCache<DeviceId, ReadWriteLock> deviceLocks = CacheBuilder.newBuilder()
@@ -84,10 +85,10 @@
     private AtomicCounter electionIdGenerator;
 
     @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
-    public GrpcController grpcController;
+    private GrpcController grpcController;
 
     @Reference(cardinality = ReferenceCardinality.MANDATORY_UNARY)
-    public StorageService storageService;
+    private StorageService storageService;
 
     @Activate
     public void activate() {
@@ -105,31 +106,44 @@
         log.info("Stopped");
     }
 
-
     @Override
-    public boolean createClient(DeviceId deviceId, long p4DeviceId, ManagedChannelBuilder channelBuilder) {
+    public boolean createClient(DeviceId deviceId, String serverAddr,
+                                int serverPort, long p4DeviceId) {
         checkNotNull(deviceId);
-        checkNotNull(channelBuilder);
+        checkNotNull(serverAddr);
+
+        ClientKey newKey = new ClientKey(deviceId, serverAddr, serverPort, p4DeviceId);
+
+        ManagedChannelBuilder channelBuilder = NettyChannelBuilder
+                .forAddress(serverAddr, serverPort)
+                .usePlaintext(true);
 
         deviceLocks.getUnchecked(deviceId).writeLock().lock();
-        log.info("Creating client for {} (with internal device id {})...", deviceId, p4DeviceId);
+        log.info("Creating client for {} (server={}:{}, p4DeviceId={})...",
+                 deviceId, serverAddr, serverPort, p4DeviceId);
 
         try {
-            if (clients.containsKey(deviceId)) {
-                // TODO might want to consider a more fine-grained check such as same port/p4DeviceId
-                log.warn("A client already exists for {}", deviceId);
-                throw new IllegalStateException(format("A client already exists for %s", deviceId));
+            if (deviceIdToClientKey.containsKey(deviceId)) {
+                final ClientKey existingKey = deviceIdToClientKey.get(deviceId);
+                if (newKey.equals(existingKey)) {
+                    return true;
+                } else {
+                    throw new IllegalStateException(
+                            "A client for the same device ID but different " +
+                                    "server endpoints already exists");
+                }
             } else {
-                return doCreateClient(deviceId, p4DeviceId, channelBuilder);
+                return doCreateClient(newKey, channelBuilder);
             }
         } finally {
             deviceLocks.getUnchecked(deviceId).writeLock().unlock();
         }
     }
 
-    private boolean doCreateClient(DeviceId deviceId, long p4DeviceId, ManagedChannelBuilder channelBuilder) {
+    private boolean doCreateClient(ClientKey clientKey, ManagedChannelBuilder channelBuilder) {
 
-        GrpcChannelId channelId = GrpcChannelId.of(deviceId, "p4runtime");
+        GrpcChannelId channelId = GrpcChannelId.of(clientKey.deviceId(),
+                                                   "p4runtime-" + clientKey.p4DeviceId());
 
         // Channel defaults.
         channelBuilder.nameResolverFactory(nameResolverProvider);
@@ -138,14 +152,17 @@
         try {
             channel = grpcController.connectChannel(channelId, channelBuilder);
         } catch (IOException e) {
-            log.warn("Unable to connect to gRPC server of {}: {}", deviceId, e.getMessage());
+            log.warn("Unable to connect to gRPC server of {}: {}",
+                     clientKey.deviceId(), e.getMessage());
             return false;
         }
 
-        P4RuntimeClient client = new P4RuntimeClientImpl(deviceId, p4DeviceId, channel, this);
+        P4RuntimeClient client = new P4RuntimeClientImpl(
+                clientKey.deviceId(), clientKey.p4DeviceId(), channel, this);
 
-        channelIds.put(deviceId, channelId);
-        clients.put(deviceId, client);
+        channelIds.put(clientKey.deviceId(), channelId);
+        deviceIdToClientKey.put(clientKey.deviceId(), clientKey);
+        clientKeyToClient.put(clientKey, client);
 
         return true;
     }
@@ -156,7 +173,11 @@
         deviceLocks.getUnchecked(deviceId).readLock().lock();
 
         try {
-            return clients.get(deviceId);
+            if (!deviceIdToClientKey.containsKey(deviceId)) {
+                return null;
+            } else {
+                return clientKeyToClient.get(deviceIdToClientKey.get(deviceId));
+            }
         } finally {
             deviceLocks.getUnchecked(deviceId).readLock().unlock();
         }
@@ -168,10 +189,11 @@
         deviceLocks.getUnchecked(deviceId).writeLock().lock();
 
         try {
-            if (clients.containsKey(deviceId)) {
-                clients.get(deviceId).shutdown();
+            if (deviceIdToClientKey.containsKey(deviceId)) {
+                final ClientKey clientKey = deviceIdToClientKey.get(deviceId);
                 grpcController.disconnectChannel(channelIds.get(deviceId));
-                clients.remove(deviceId);
+                clientKeyToClient.remove(clientKey).shutdown();
+                deviceIdToClientKey.remove(deviceId);
                 channelIds.remove(deviceId);
             }
         } finally {
@@ -184,7 +206,7 @@
         deviceLocks.getUnchecked(deviceId).readLock().lock();
 
         try {
-            return clients.containsKey(deviceId);
+            return deviceIdToClientKey.containsKey(deviceId);
         } finally {
             deviceLocks.getUnchecked(deviceId).readLock().unlock();
         }
@@ -196,8 +218,8 @@
         deviceLocks.getUnchecked(deviceId).readLock().lock();
 
         try {
-            if (!clients.containsKey(deviceId)) {
-                log.warn("No client for {}, can't check for reachability", deviceId);
+            if (!deviceIdToClientKey.containsKey(deviceId)) {
+                log.debug("No client for {}, can't check for reachability", deviceId);
                 return false;
             }
 
@@ -239,7 +261,7 @@
         });
     }
 
-    public void postEvent(P4RuntimeEvent event) {
+    void postEvent(P4RuntimeEvent event) {
         if (event.type().equals(P4RuntimeEvent.Type.CHANNEL_EVENT)) {
             DefaultChannelEvent channelError = (DefaultChannelEvent) event.subject();
             DeviceId deviceId = event.subject().deviceId();
@@ -270,4 +292,5 @@
             post(event);
         }
     }
+
 }