ONOS-4118 Added synchronization and resiliency to Bmv2ThriftClient

Due to the multi-threaded nature of drivers, calls to a Bmv2ThriftClient
could result in a race condition if not properly synchronized. Also,
once open, a transport session might close for several reasons. Now the
client calls are synchronized and automatically wrapped in a try/catch
that tries to re-open the session a fixed number of times before
giving up.

Change-Id: I5dcdd5a6304406dc6d9d3a0ccf7f16cdbf3b9573
diff --git a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java
index 3c6acf5..2947b5f 100644
--- a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java
+++ b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2FlowRuleDriver.java
@@ -20,6 +20,7 @@
 import com.google.common.collect.Maps;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.commons.lang3.tuple.Triple;
+import org.onosproject.bmv2.api.runtime.Bmv2Client;
 import org.onosproject.bmv2.api.runtime.Bmv2MatchKey;
 import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException;
 import org.onosproject.bmv2.api.runtime.Bmv2TableEntry;
@@ -84,7 +85,7 @@
 
         DeviceId deviceId = handler().data().deviceId();
 
-        Bmv2ThriftClient deviceClient;
+        Bmv2Client deviceClient;
         try {
             deviceClient = Bmv2ThriftClient.of(deviceId);
         } catch (Bmv2RuntimeException e) {
diff --git a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java
index 927b77d..e1da42b 100644
--- a/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java
+++ b/drivers/bmv2/src/main/java/org/onosproject/drivers/bmv2/Bmv2PortGetterDriver.java
@@ -18,6 +18,7 @@
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
+import org.onosproject.bmv2.api.runtime.Bmv2Client;
 import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException;
 import org.onosproject.bmv2.ctl.Bmv2ThriftClient;
 import org.onosproject.net.DefaultAnnotations;
@@ -41,7 +42,7 @@
 
     @Override
     public List<PortDescription> getPorts() {
-        Bmv2ThriftClient deviceClient;
+        Bmv2Client deviceClient;
         try {
             deviceClient = Bmv2ThriftClient.of(handler().data().deviceId());
         } catch (Bmv2RuntimeException e) {
diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java
new file mode 100644
index 0000000..252f59b
--- /dev/null
+++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/api/runtime/Bmv2Client.java
@@ -0,0 +1,89 @@
+/*
+ * Copyright 2016-present Open Networking Laboratory
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.onosproject.bmv2.api.runtime;
+
+import java.util.Collection;
+
+/**
+ * RPC client to control a BMv2 device.
+ */
+public interface Bmv2Client {
+    /**
+     * Adds a new table entry.
+     *
+     * @param entry a table entry value
+     * @return table-specific entry ID
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    long addTableEntry(Bmv2TableEntry entry) throws Bmv2RuntimeException;
+
+    /**
+     * Modifies a currently installed entry by updating its action.
+     *
+     * @param tableName name of the table containing the entry
+     * @param entryId   table-specific ID of the entry to modify
+     * @param action    the new action of the entry
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    void modifyTableEntry(String tableName,
+                          long entryId, Bmv2Action action)
+            throws Bmv2RuntimeException;
+
+    /**
+     * Deletes a currently installed entry.
+     *
+     * @param tableName name of the table containing the entry
+     * @param entryId   table-specific ID of the entry to delete
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    void deleteTableEntry(String tableName,
+                          long entryId) throws Bmv2RuntimeException;
+
+    /**
+     * Sets the default action of a table.
+     *
+     * @param tableName name of the table
+     * @param action    the action to set as default
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    void setTableDefaultAction(String tableName, Bmv2Action action)
+            throws Bmv2RuntimeException;
+
+    /**
+     * Returns information of the ports currently configured in the switch.
+     *
+     * @return collection of port information
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    Collection<Bmv2PortInfo> getPortsInfo() throws Bmv2RuntimeException;
+
+    /**
+     * Returns a string representation of the content of a table.
+     *
+     * @param tableName name of the table to dump
+     * @return table string dump
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    String dumpTable(String tableName) throws Bmv2RuntimeException;
+
+    /**
+     * Resets the state of the switch (e.g. deletes all entries, etc.).
+     *
+     * @throws Bmv2RuntimeException if any error occurs
+     */
+    void resetState() throws Bmv2RuntimeException;
+}
diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
index eb6687a..f1a86fc 100644
--- a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
+++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
@@ -31,13 +31,13 @@
 import org.apache.thrift.transport.TSocket;
 import org.apache.thrift.transport.TTransport;
 import org.apache.thrift.transport.TTransportException;
-import org.onlab.util.ImmutableByteSequence;
 import org.onosproject.bmv2.api.runtime.Bmv2Action;
+import org.onosproject.bmv2.api.runtime.Bmv2Client;
 import org.onosproject.bmv2.api.runtime.Bmv2ExactMatchParam;
-import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException;
 import org.onosproject.bmv2.api.runtime.Bmv2LpmMatchParam;
 import org.onosproject.bmv2.api.runtime.Bmv2MatchKey;
 import org.onosproject.bmv2.api.runtime.Bmv2PortInfo;
+import org.onosproject.bmv2.api.runtime.Bmv2RuntimeException;
 import org.onosproject.bmv2.api.runtime.Bmv2TableEntry;
 import org.onosproject.bmv2.api.runtime.Bmv2TernaryMatchParam;
 import org.onosproject.bmv2.api.runtime.Bmv2ValidMatchParam;
@@ -51,6 +51,8 @@
 import org.p4.bmv2.thrift.BmMatchParamValid;
 import org.p4.bmv2.thrift.DevMgrPortInfo;
 import org.p4.bmv2.thrift.Standard;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
 
 import java.nio.ByteBuffer;
 import java.util.Collection;
@@ -60,38 +62,43 @@
 import java.util.stream.Collectors;
 
 import static com.google.common.base.Preconditions.checkNotNull;
+import static org.onosproject.bmv2.ctl.SafeThriftClient.Options;
 
 /**
  * Implementation of a Thrift client to control the Bmv2 switch.
  */
-public final class Bmv2ThriftClient {
-    /*
-    FIXME: derive context_id from device id
-    Using different context id values should serve to control different
-    switches responding to the same IP address and port
-    */
+public final class Bmv2ThriftClient implements Bmv2Client {
+
+    private static final Logger LOG =
+            LoggerFactory.getLogger(Bmv2ThriftClient.class);
+
+    // FIXME: make context_id arbitrary for each call
+    // See: https://github.com/p4lang/behavioral-model/blob/master/modules/bm_sim/include/bm_sim/context.h
     private static final int CONTEXT_ID = 0;
-    /*
-    Static transport/client cache:
-        - avoids opening a new transport session when there's one already open
-        - close the connection after a predefined timeout of 5 seconds
-     */
-    private static LoadingCache<DeviceId, Bmv2ThriftClient>
-            clientCache = CacheBuilder.newBuilder()
-            .expireAfterAccess(5, TimeUnit.SECONDS)
+    // Seconds after a client is expired (and connection closed) in the cache.
+    private static final int CLIENT_CACHE_TIMEOUT = 60;
+    // Number of connection retries after a network error.
+    private static final int NUM_CONNECTION_RETRIES = 10;
+    // Time between retries in milliseconds.
+    private static final int TIME_BETWEEN_RETRIES = 200;
+
+    // Static client cache where clients are removed after a predefined timeout.
+    private static final LoadingCache<DeviceId, Bmv2ThriftClient>
+            CLIENT_CACHE = CacheBuilder.newBuilder()
+            .expireAfterAccess(CLIENT_CACHE_TIMEOUT, TimeUnit.SECONDS)
             .removalListener(new ClientRemovalListener())
             .build(new ClientLoader());
     private final Standard.Iface stdClient;
     private final TTransport transport;
+    private final DeviceId deviceId;
 
     // ban constructor
-    private Bmv2ThriftClient(TTransport transport, Standard.Iface stdClient) {
+    private Bmv2ThriftClient(DeviceId deviceId, TTransport transport, Standard.Iface stdClient) {
+        this.deviceId = deviceId;
         this.transport = transport;
         this.stdClient = stdClient;
-    }
 
-    private void closeTransport() {
-        this.transport.close();
+        LOG.debug("New client created! > deviceId={}", deviceId);
     }
 
     /**
@@ -104,8 +111,10 @@
     public static Bmv2ThriftClient of(DeviceId deviceId) throws Bmv2RuntimeException {
         try {
             checkNotNull(deviceId, "deviceId cannot be null");
-            return clientCache.get(deviceId);
+            LOG.debug("Getting a client from cache... > deviceId={}", deviceId);
+            return CLIENT_CACHE.get(deviceId);
         } catch (ExecutionException e) {
+            LOG.debug("Exception while getting a client from cache: {} > deviceId={}", e, deviceId);
             throw new Bmv2RuntimeException(e.getMessage(), e.getCause());
         }
     }
@@ -120,9 +129,13 @@
     public static boolean ping(DeviceId deviceId) {
         // poll ports status as workaround to assess device reachability
         try {
-            of(deviceId).stdClient.bm_dev_mgr_show_ports();
+            LOG.debug("Pinging device... > deviceId={}", deviceId);
+            Bmv2ThriftClient client = of(deviceId);
+            client.stdClient.bm_dev_mgr_show_ports();
+            LOG.debug("Device reachable! > deviceId={}", deviceId);
             return true;
         } catch (TException | Bmv2RuntimeException e) {
+            LOG.debug("Device NOT reachable! > deviceId={}", deviceId);
             return false;
         }
     }
@@ -156,32 +169,34 @@
     private static List<BmMatchParam> buildMatchParamsList(Bmv2MatchKey matchKey) {
         List<BmMatchParam> paramsList = Lists.newArrayList();
         matchKey.matchParams().forEach(x -> {
+            ByteBuffer value;
+            ByteBuffer mask;
             switch (x.type()) {
                 case EXACT:
+                    value = ByteBuffer.wrap(((Bmv2ExactMatchParam) x).value().asArray());
                     paramsList.add(
                             new BmMatchParam(BmMatchParamType.EXACT)
-                                    .setExact(new BmMatchParamExact(
-                                            ((Bmv2ExactMatchParam) x).value().asReadOnlyBuffer())));
+                                    .setExact(new BmMatchParamExact(value)));
                     break;
                 case TERNARY:
+                    value = ByteBuffer.wrap(((Bmv2TernaryMatchParam) x).value().asArray());
+                    mask = ByteBuffer.wrap(((Bmv2TernaryMatchParam) x).mask().asArray());
                     paramsList.add(
                             new BmMatchParam(BmMatchParamType.TERNARY)
-                                    .setTernary(new BmMatchParamTernary(
-                                            ((Bmv2TernaryMatchParam) x).value().asReadOnlyBuffer(),
-                                            ((Bmv2TernaryMatchParam) x).mask().asReadOnlyBuffer())));
+                                    .setTernary(new BmMatchParamTernary(value, mask)));
                     break;
                 case LPM:
+                    value = ByteBuffer.wrap(((Bmv2LpmMatchParam) x).value().asArray());
+                    int prefixLength = ((Bmv2LpmMatchParam) x).prefixLength();
                     paramsList.add(
                             new BmMatchParam(BmMatchParamType.LPM)
-                                    .setLpm(new BmMatchParamLPM(
-                                            ((Bmv2LpmMatchParam) x).value().asReadOnlyBuffer(),
-                                            ((Bmv2LpmMatchParam) x).prefixLength())));
+                                    .setLpm(new BmMatchParamLPM(value, prefixLength)));
                     break;
                 case VALID:
+                    boolean flag = ((Bmv2ValidMatchParam) x).flag();
                     paramsList.add(
                             new BmMatchParam(BmMatchParamType.VALID)
-                                    .setValid(new BmMatchParamValid(
-                                            ((Bmv2ValidMatchParam) x).flag())));
+                                    .setValid(new BmMatchParamValid(flag)));
                     break;
                 default:
                     // should never be here
@@ -198,21 +213,26 @@
      * @return list of ByteBuffers
      */
     private static List<ByteBuffer> buildActionParamsList(Bmv2Action action) {
-        return action.parameters()
-                .stream()
-                .map(ImmutableByteSequence::asReadOnlyBuffer)
-                .collect(Collectors.toList());
+        List<ByteBuffer> buffers = Lists.newArrayList();
+        action.parameters().forEach(p -> buffers.add(ByteBuffer.wrap(p.asArray())));
+        return buffers;
     }
 
-    /**
-     * Adds a new table entry.
-     *
-     * @param entry a table entry value
-     * @return table-specific entry ID
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    private void closeTransport() {
+        LOG.debug("Closing transport session... > deviceId={}", deviceId);
+        if (this.transport.isOpen()) {
+            this.transport.close();
+            LOG.debug("Transport session closed! > deviceId={}", deviceId);
+        } else {
+            LOG.debug("Transport session was already closed! deviceId={}", deviceId);
+        }
+    }
+
+    @Override
     public final long addTableEntry(Bmv2TableEntry entry) throws Bmv2RuntimeException {
 
+        LOG.debug("Adding table entry... > deviceId={}, entry={}", deviceId, entry);
+
         long entryId = -1;
 
         try {
@@ -237,34 +257,33 @@
                         CONTEXT_ID, entry.tableName(), entryId, msTimeout);
             }
 
+            LOG.debug("Table entry added! > deviceId={}, entryId={}/{}", deviceId, entry.tableName(), entryId);
+
             return entryId;
 
         } catch (TException e) {
+            LOG.debug("Exception while adding table entry: {} > deviceId={}, tableName={}",
+                      e, deviceId, entry.tableName());
             if (entryId != -1) {
+                // entry is in inconsistent state (unable to add timeout), remove it
                 try {
-                    stdClient.bm_mt_delete_entry(
-                            CONTEXT_ID, entry.tableName(), entryId);
-                } catch (TException e1) {
-                    // this should never happen as we know the entry is there
-                    throw new Bmv2RuntimeException(e1.getMessage(), e1);
+                    deleteTableEntry(entry.tableName(), entryId);
+                } catch (Bmv2RuntimeException e1) {
+                    LOG.debug("Unable to remove failed table entry: {} > deviceId={}, tableName={}",
+                              e1, deviceId, entry.tableName());
                 }
             }
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Modifies a currently installed entry by updating its action.
-     *
-     * @param tableName string value of table name
-     * @param entryId   long value of entry ID
-     * @param action    an action value
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public final void modifyTableEntry(String tableName,
                                        long entryId, Bmv2Action action)
             throws Bmv2RuntimeException {
 
+        LOG.debug("Modifying table entry... > deviceId={}, entryId={}/{}", deviceId, tableName, entryId);
+
         try {
             stdClient.bm_mt_modify_entry(
                     CONTEXT_ID,
@@ -272,57 +291,55 @@
                     entryId,
                     action.name(),
                     buildActionParamsList(action));
+            LOG.debug("Table entry modified! > deviceId={}, entryId={}/{}", deviceId, tableName, entryId);
         } catch (TException e) {
+            LOG.debug("Exception while modifying table entry: {} > deviceId={}, entryId={}/{}",
+                      e, deviceId, tableName, entryId);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Deletes currently installed entry.
-     *
-     * @param tableName string value of table name
-     * @param entryId   long value of entry ID
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public final void deleteTableEntry(String tableName,
                                        long entryId) throws Bmv2RuntimeException {
 
+        LOG.debug("Deleting table entry... > deviceId={}, entryId={}/{}", deviceId, tableName, entryId);
+
         try {
             stdClient.bm_mt_delete_entry(CONTEXT_ID, tableName, entryId);
+            LOG.debug("Table entry deleted! > deviceId={}, entryId={}/{}", deviceId, tableName, entryId);
         } catch (TException e) {
+            LOG.debug("Exception while deleting table entry: {} > deviceId={}, entryId={}/{}",
+                      e, deviceId, tableName, entryId);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Sets table default action.
-     *
-     * @param tableName string value of table name
-     * @param action    an action value
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public final void setTableDefaultAction(String tableName, Bmv2Action action)
             throws Bmv2RuntimeException {
 
+        LOG.debug("Setting table default... > deviceId={}, tableName={}, action={}", deviceId, tableName, action);
+
         try {
             stdClient.bm_mt_set_default_action(
                     CONTEXT_ID,
                     tableName,
                     action.name(),
                     buildActionParamsList(action));
+            LOG.debug("Table default set! > deviceId={}, tableName={}, action={}", deviceId, tableName, action);
         } catch (TException e) {
+            LOG.debug("Exception while setting table default : {} > deviceId={}, tableName={}, action={}",
+                      e, deviceId, tableName, action);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Returns information of the ports currently configured in the switch.
-     *
-     * @return collection of port information
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public Collection<Bmv2PortInfo> getPortsInfo() throws Bmv2RuntimeException {
 
+        LOG.debug("Retrieving port info... > deviceId={}", deviceId);
+
         try {
             List<DevMgrPortInfo> portInfos = stdClient.bm_dev_mgr_show_ports();
 
@@ -333,39 +350,42 @@
                             .map(Bmv2PortInfo::new)
                             .collect(Collectors.toList()));
 
+            LOG.debug("Port info retrieved! > deviceId={}, portInfos={}", deviceId, bmv2PortInfos);
+
             return bmv2PortInfos;
 
         } catch (TException e) {
+            LOG.debug("Exception while retrieving port info: {} > deviceId={}", e, deviceId);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Return a string representation of a table content.
-     *
-     * @param tableName string value of table name
-     * @return table string dump
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public String dumpTable(String tableName) throws Bmv2RuntimeException {
 
+        LOG.debug("Retrieving table dump... > deviceId={}, tableName={}", deviceId, tableName);
+
         try {
-            return stdClient.bm_dump_table(CONTEXT_ID, tableName);
+            String dump = stdClient.bm_dump_table(CONTEXT_ID, tableName);
+            LOG.debug("Table dump retrieved! > deviceId={}, tableName={}", deviceId, tableName);
+            return dump;
         } catch (TException e) {
+            LOG.debug("Exception while retrieving table dump: {} > deviceId={}, tableName={}",
+                      e, deviceId, tableName);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
 
-    /**
-     * Reset the state of the switch (e.g. delete all entries, etc.).
-     *
-     * @throws Bmv2RuntimeException if any error occurs
-     */
+    @Override
     public void resetState() throws Bmv2RuntimeException {
 
+        LOG.debug("Resetting device state... > deviceId={}", deviceId);
+
         try {
             stdClient.bm_reset_state();
+            LOG.debug("Device state reset! > deviceId={}", deviceId);
         } catch (TException e) {
+            LOG.debug("Exception while resetting device state: {} > deviceId={}", e, deviceId);
             throw new Bmv2RuntimeException(e.getMessage(), e);
         }
     }
@@ -376,20 +396,26 @@
     private static class ClientLoader
             extends CacheLoader<DeviceId, Bmv2ThriftClient> {
 
+        // Connection retries options: max 10 retries each 200 ms
+        private static final Options RECONN_OPTIONS = new Options(NUM_CONNECTION_RETRIES, TIME_BETWEEN_RETRIES);
+
         @Override
         public Bmv2ThriftClient load(DeviceId deviceId)
                 throws TTransportException {
+            LOG.debug("Creating new client in cache... > deviceId={}", deviceId);
             Pair<String, Integer> info = parseDeviceId(deviceId);
             //make the expensive call
             TTransport transport = new TSocket(
                     info.getLeft(), info.getRight());
             TProtocol protocol = new TBinaryProtocol(transport);
-            Standard.Iface stdClient = new Standard.Client(
+            Standard.Client stdClient = new Standard.Client(
                     new TMultiplexedProtocol(protocol, "standard"));
+            // Wrap the client so to automatically have synchronization and resiliency to connectivity problems
+            Standard.Iface reconnStdIface = SafeThriftClient.wrap(stdClient,
+                                                                  Standard.Iface.class,
+                                                                  RECONN_OPTIONS);
 
-            transport.open();
-
-            return new Bmv2ThriftClient(transport, stdClient);
+            return new Bmv2ThriftClient(deviceId, transport, reconnStdIface);
         }
     }
 
@@ -403,6 +429,7 @@
         public void onRemoval(
                 RemovalNotification<DeviceId, Bmv2ThriftClient> notification) {
             // close the transport connection
+            LOG.debug("Removing client from cache... > deviceId={}", notification.getKey());
             notification.getValue().closeTransport();
         }
     }
diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java
new file mode 100644
index 0000000..bbe0546
--- /dev/null
+++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java
@@ -0,0 +1,247 @@
+/*
+ * Copyright 2016-present Open Networking Laboratory
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Most of the code of this class was copied from:
+ * http://liveramp.com/engineering/reconnecting-thrift-client/
+ */
+
+package org.onosproject.bmv2.ctl;
+
+import com.google.common.collect.ImmutableSet;
+import org.apache.thrift.TServiceClient;
+import org.apache.thrift.transport.TTransport;
+import org.apache.thrift.transport.TTransportException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.lang.reflect.InvocationHandler;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.lang.reflect.Proxy;
+import java.util.Set;
+
+/**
+ * Thrift client wrapper that attempts a few reconnects before giving up a method call execution. It also provides
+ * synchronization between calls (automatically serializes multiple calls to the same client from different threads).
+ */
+public final class SafeThriftClient {
+
+    private static final Logger LOG = LoggerFactory.getLogger(SafeThriftClient.class);
+
+    /**
+     * List of causes which suggest a restart might fix things (defined as constants in {@link TTransportException}).
+     */
+    private static final Set<Integer> RESTARTABLE_CAUSES = ImmutableSet.of(TTransportException.NOT_OPEN,
+                                                                           TTransportException.END_OF_FILE,
+                                                                           TTransportException.TIMED_OUT,
+                                                                           TTransportException.UNKNOWN);
+
+    private SafeThriftClient() {
+        // ban constructor.
+    }
+
+    /**
+     * Reflectively wraps an already existing Thrift client.
+     *
+     * @param baseClient      the client to wrap
+     * @param clientInterface the interface that the client implements
+     * @param options         options that control behavior of the reconnecting client
+     * @param <T>             type of the Thrift client to wrap
+     * @param <C>             type of the client interface implemented by the returned proxy
+     * @return a proxy implementing {@code clientInterface} that forwards calls to {@code baseClient}
+     */
+    public static <T extends TServiceClient, C> C wrap(T baseClient, Class<C> clientInterface, Options options) {
+        Object proxyObject = Proxy.newProxyInstance(clientInterface.getClassLoader(),
+                                                    new Class<?>[]{clientInterface},
+                                                    new ReconnectingClientProxy<T>(baseClient,
+                                                                                   options.getNumRetries(),
+                                                                                   options.getTimeBetweenRetries()));
+
+        return (C) proxyObject;
+    }
+
+    /**
+     * Reflectively wraps an already existing Thrift client, using the {@code Iface} it directly implements.
+     *
+     * @param baseClient the client to wrap
+     * @param options    options that control behavior of the reconnecting client
+     * @param <T>        type of the Thrift client to wrap
+     * @param <C>        type of the returned proxy, expected to be the {@code Iface} found on the client
+     * @return a proxy implementing the {@code Iface} declared in the same enclosing class as the client's class
+     */
+    public static <T extends TServiceClient, C> C wrap(T baseClient, Options options) {
+        Class<?>[] interfaces = baseClient.getClass().getInterfaces();
+
+        for (Class<?> iface : interfaces) {
+            if (iface.getSimpleName().equals("Iface")
+                    && iface.getEnclosingClass().equals(baseClient.getClass().getEnclosingClass())) {
+                return (C) wrap(baseClient, iface, options);
+            }
+        }
+        // No directly implemented Iface was found: the interface has to be passed explicitly.
+        throw new RuntimeException("Class needs to implement Iface directly. Use wrap(TServiceClient, Class) instead.");
+    }
+
+    /**
+     * Reflectively wraps an already existing Thrift client using the default options (5 retries, 10000 ms apart).
+     *
+     * @param baseClient      the client to wrap
+     * @param clientInterface the interface that the client implements
+     * @param <T>             type of the Thrift client to wrap
+     * @param <C>             type of the client interface implemented by the returned proxy
+     * @return a proxy implementing {@code clientInterface} that forwards calls to {@code baseClient}
+     */
+    public static <T extends TServiceClient, C> C wrap(T baseClient, Class<C> clientInterface) {
+        return wrap(baseClient, clientInterface, Options.defaults());
+    }
+
+    /**
+     * Reflectively wraps an already existing Thrift client using the default options (5 retries, 10000 ms apart).
+     *
+     * @param baseClient the client to wrap
+     * @param <T>        type of the Thrift client to wrap
+     * @param <C>        type of the returned proxy, expected to be the {@code Iface} found on the client
+     * @return a proxy implementing the {@code Iface} declared in the same enclosing class as the client's class
+     */
+    public static <T extends TServiceClient, C> C wrap(T baseClient) {
+        return wrap(baseClient, Options.defaults());
+    }
+
+    /**
+     * Reconnection options for {@link SafeThriftClient}. Instances are immutable.
+     */
+    public static class Options {
+        private final int numRetries;
+        private final long timeBetweenRetries;
+
+        /**
+         * Creates new options with the given parameters.
+         *
+         * @param numRetries         the maximum number of times to try reconnecting before giving up and throwing an
+         *                           exception
+         * @param timeBetweenRetries the number of milliseconds to wait in between reconnection attempts.
+         */
+        public Options(int numRetries, long timeBetweenRetries) {
+            this.numRetries = numRetries;
+            this.timeBetweenRetries = timeBetweenRetries;
+        }
+
+        private static Options defaults() {
+            return new Options(5, 10000L);
+        }
+
+        private int getNumRetries() {
+            return numRetries;
+        }
+
+        private long getTimeBetweenRetries() {
+            return timeBetweenRetries;
+        }
+    }
+
+    /**
+     * Helper proxy class. Attempts to call method on proxy object wrapped in try/catch. If it fails, it attempts a
+     * reconnect and tries the method again.
+     *
+     * @param <T>
+     */
+    private static class ReconnectingClientProxy<T extends TServiceClient> implements InvocationHandler {
+        private final T baseClient;
+        private final int maxRetries;
+        private final long timeBetweenRetries;
+
+        public ReconnectingClientProxy(T baseClient, int maxRetries, long timeBetweenRetries) {
+            this.baseClient = baseClient;
+            this.maxRetries = maxRetries;
+            this.timeBetweenRetries = timeBetweenRetries;
+        }
+
+        private static void reconnectOrThrowException(TTransport transport, int maxRetries, long timeBetweenRetries)
+                throws TTransportException {
+            // Drop the current session, then try to re-open it up to maxRetries times, pausing between attempts.
+            int errors = 0;
+            transport.close();
+            while (errors < maxRetries) {
+                try {
+                    LOG.debug("Attempting to reconnect...");
+                    transport.open();
+                    LOG.debug("Reconnection successful");
+                    break;
+                } catch (TTransportException e) {
+                    LOG.error("Error while reconnecting:", e);
+                    errors++;
+                    if (errors < maxRetries) {
+                        try {
+                            LOG.debug("Sleeping for {} milliseconds before retrying", timeBetweenRetries);
+                            Thread.sleep(timeBetweenRetries);
+                        } catch (InterruptedException e2) {
+                            // Restore the interrupt flag and report the interruption itself as the cause.
+                            Thread.currentThread().interrupt();
+                            throw new RuntimeException(e2);
+                        }
+                    }
+                }
+            }
+
+            if (errors >= maxRetries) {
+                throw new TTransportException("Failed to reconnect");
+            }
+        }
+
+        @Override
+        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
+
+            // With Thrift, a client must be instantiated for each transport session, i.e. for each server instance.
+            // Hence, using baseClient as lock, only calls towards the same server will be synchronized.
+
+            synchronized (baseClient) {
+
+                LOG.debug("Invoking client method... > method={}, fromThread={}",
+                          method.getName(), Thread.currentThread().getId());
+
+                Object result = null;
+
+                try {
+                    result = method.invoke(baseClient, args);
+                } catch (InvocationTargetException e) {
+                    Throwable target = e.getTargetException();
+                    boolean restartable = target instanceof TTransportException
+                            && RESTARTABLE_CAUSES.contains(((TTransportException) target).getType());
+                    if (!restartable) {
+                        LOG.debug("Exception while invoking client method: {} > method={}, fromThread={}",
+                                  e, method.getName(), Thread.currentThread().getId());
+                        throw target;
+                    }
+                    reconnectOrThrowException(baseClient.getInputProtocol().getTransport(),
+                                              maxRetries,
+                                              timeBetweenRetries);
+                    try {
+                        // A null result is legitimate (void methods), so it must not be read as a failed retry.
+                        result = method.invoke(baseClient, args);
+                    } catch (InvocationTargetException retryFailure) {
+                        throw retryFailure.getTargetException();
+                    }
+                }
+
+                LOG.debug("Method invoke complete! > method={}, fromThread={}",
+                          method.getName(), Thread.currentThread().getId());
+
+                return result;
+            }
+        }
+    }
+}