diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
index 5e6cbd9..28d88e9 100644
--- a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
+++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/Bmv2ThriftClient.java
@@ -236,16 +236,6 @@
         return buffers;
     }
 
-    private void closeTransport() {
-        LOG.debug("Closing transport session... > deviceId={}", deviceId);
-        if (this.transport.isOpen()) {
-            this.transport.close();
-            LOG.debug("Transport session closed! > deviceId={}", deviceId);
-        } else {
-            LOG.debug("Transport session was already closed! deviceId={}", deviceId);
-        }
-    }
-
     @Override
     public final long addTableEntry(Bmv2TableEntry entry) throws Bmv2RuntimeException {
 
@@ -527,7 +517,7 @@
             TTransport transport = new TSocket(
                     info.getLeft(), info.getRight());
             TProtocol protocol = new TBinaryProtocol(transport);
-            // Our BMv2 device implements multiple Thrift services, create a client for each one.
+            // Our BMv2 device implements multiple Thrift services, create a client for each one on the same transport.
             Standard.Client standardClient = new Standard.Client(
                     new TMultiplexedProtocol(protocol, "standard"));
             SimpleSwitch.Client simpleSwitch = new SimpleSwitch.Client(
@@ -551,11 +541,20 @@
             RemovalListener<DeviceId, Bmv2ThriftClient> {
 
         @Override
-        public void onRemoval(
-                RemovalNotification<DeviceId, Bmv2ThriftClient> notification) {
+        public void onRemoval(RemovalNotification<DeviceId, Bmv2ThriftClient> notification) {
             // close the transport connection
-            LOG.debug("Removing client from cache... > deviceId={}", notification.getKey());
-            notification.getValue().closeTransport();
+            Bmv2ThriftClient client = notification.getValue();
+            // Locking here is ugly, but needed (see SafeThriftClient).
+            synchronized (client.transport) {
+                LOG.debug("Closing transport session... > deviceId={}", client.deviceId);
+                if (client.transport.isOpen()) {
+                    client.transport.close();
+                    LOG.debug("Transport session closed! > deviceId={}", client.deviceId);
+                } else {
+                    LOG.debug("Transport session was already closed! deviceId={}", client.deviceId);
+                }
+            }
+            LOG.debug("Removing client from cache... > deviceId={}", client.deviceId);
         }
     }
 }
diff --git a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java
index 95a052a..98813f9 100644
--- a/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java
+++ b/protocols/bmv2/src/main/java/org/onosproject/bmv2/ctl/SafeThriftClient.java
@@ -35,8 +35,8 @@
 import java.util.Set;
 
 /**
- * Thrift client wrapper that attempts a few reconnects before giving up a method call execution. It al provides
- * synchronization between calls (automatically serialize multiple calls to the same client from different threads).
+ * Thrift client wrapper that attempts a few reconnects before giving up a method call execution. It also provides
+ * synchronization between calls over the same transport.
  */
 public final class SafeThriftClient {
 
@@ -161,23 +161,28 @@
      */
     private static class ReconnectingClientProxy<T extends TServiceClient> implements InvocationHandler {
         private final T baseClient;
+        private final TTransport transport;
         private final int maxRetries;
         private final long timeBetweenRetries;
 
         public ReconnectingClientProxy(T baseClient, int maxRetries, long timeBetweenRetries) {
             this.baseClient = baseClient;
+            this.transport = baseClient.getInputProtocol().getTransport();
             this.maxRetries = maxRetries;
             this.timeBetweenRetries = timeBetweenRetries;
         }
 
-        private static void reconnectOrThrowException(TTransport transport, int maxRetries, long timeBetweenRetries)
+        private void reconnectOrThrowException()
                 throws TTransportException {
             int errors = 0;
             try {
-                transport.close();
+                if (transport.isOpen()) {
+                    transport.close();
+                }
             } catch (Exception e) {
                 // Thrift seems to have a bug where if the transport is already closed a SocketException is thrown.
                 // However, such an exception is not advertised by .close(), hence the general-purpose catch.
+                LOG.debug("Exception while closing transport", e);
             }
 
             while (errors < maxRetries) {
@@ -210,42 +215,37 @@
         @Override
         public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
 
-            // With Thrift clients must be instantiated for each different transport session, i.e. server instance.
-            // Hence, using baseClient as lock, only calls towards the same server will be synchronized.
-
-            synchronized (baseClient) {
+            // Thrift transport layer is not thread-safe (it's a wrapper on a socket), hence we need locking.
+            synchronized (transport) {
 
                 LOG.debug("Invoking client method... > method={}, fromThread={}",
                           method.getName(), Thread.currentThread().getId());
 
-                Object result = null;
-
                 try {
-                    result = method.invoke(baseClient, args);
 
+                    return method.invoke(baseClient, args);
                 } catch (InvocationTargetException e) {
                     if (e.getTargetException() instanceof TTransportException) {
                         TTransportException cause = (TTransportException) e.getTargetException();
 
                         if (RESTARTABLE_CAUSES.contains(cause.getType())) {
-                            reconnectOrThrowException(baseClient.getInputProtocol().getTransport(),
-                                                      maxRetries,
-                                                      timeBetweenRetries);
-                            result = method.invoke(baseClient, args);
+                            // Try to reconnect. If fail, a TTransportException will be thrown.
+                            reconnectOrThrowException();
+                            try {
+                                // If here, transport has been successfully open, hence new exceptions will be thrown.
+                                return method.invoke(baseClient, args);
+                            } catch (InvocationTargetException e1) {
+                                LOG.debug("Exception while invoking client method: {} > method={}, fromThread={}",
+                                          e1, method.getName(), Thread.currentThread().getId());
+                                throw e1.getTargetException();
+                            }
                         }
                     }
-
-                    if (result == null) {
-                        LOG.debug("Exception while invoking client method: {} > method={}, fromThread={}",
-                                  e, method.getName(), Thread.currentThread().getId());
-                        throw e.getTargetException();
-                    }
+                    // Target exception is neither a TTransportException nor a restartable cause.
+                    LOG.debug("Exception while invoking client method: {} > method={}, fromThread={}",
+                              e, method.getName(), Thread.currentThread().getId());
+                    throw e.getTargetException();
                 }
-
-                LOG.debug("Method invoke complete! > method={}, fromThread={}",
-                          method.getName(), Thread.currentThread().getId());
-
-                return result;
             }
         }
     }
