Add TLS support to NettyMessaging, configurable in NettyMessagingManager through JAVA_OPTS

Change-Id: I5e77658cbae70d3facbe9e1f56c9fa9fcf0e00cc
diff --git a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
index 4e6f4b7..8c759d1 100644
--- a/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
+++ b/utils/netty/src/main/java/org/onlab/netty/NettyMessaging.java
@@ -35,7 +35,10 @@
 import io.netty.channel.socket.nio.NioServerSocketChannel;
 import io.netty.channel.socket.nio.NioSocketChannel;
 
+import java.io.FileInputStream;
 import java.io.IOException;
+import java.security.KeyStore;
+
 import java.util.Map;
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ConcurrentHashMap;
@@ -60,6 +63,11 @@
 import com.google.common.cache.RemovalListener;
 import com.google.common.cache.RemovalNotification;
 
+import javax.net.ssl.KeyManagerFactory;
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.TrustManagerFactory;
+
 /**
  * Implementation of MessagingService based on <a href="http://netty.io/">Netty</a> framework.
  */
@@ -93,6 +101,14 @@
     private Class<? extends ServerChannel> serverChannelClass;
     private Class<? extends Channel> clientChannelClass;
 
+    protected static final boolean TLS_DISABLED = false;
+    protected boolean enableNettyTLS = TLS_DISABLED;
+
+    protected String ksLocation;
+    protected String tsLocation;
+    protected char[] ksPwd;
+    protected char[] tsPwd;
+
     private void initEventLoopGroup() {
         // try Epoll first and if that does work, use nio.
         try {
@@ -216,9 +232,9 @@
             handler.apply(message.payload()).whenComplete((result, error) -> {
                 if (error == null) {
                     InternalMessage response = new InternalMessage(message.id(),
-                            localEp,
-                            REPLY_MESSAGE_TYPE,
-                            result);
+                                                                   localEp,
+                                                                   REPLY_MESSAGE_TYPE,
+                                                                   result);
                     sendAsync(message.sender(), response).whenComplete((r, e) -> {
                         if (e != null) {
                             log.debug("Failed to respond", e);
@@ -241,11 +257,15 @@
         b.option(ChannelOption.SO_RCVBUF, 1048576);
         b.option(ChannelOption.TCP_NODELAY, true);
         b.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
-        b.group(serverGroup, clientGroup)
-            .channel(serverChannelClass)
-            .childHandler(new OnosCommunicationChannelInitializer())
-            .option(ChannelOption.SO_BACKLOG, 128)
-            .childOption(ChannelOption.SO_KEEPALIVE, true);
+        b.group(serverGroup, clientGroup);
+        b.channel(serverChannelClass);
+        if (enableNettyTLS) {
+            b.childHandler(new SSLServerCommunicationChannelInitializer());
+        } else {
+            b.childHandler(new OnosCommunicationChannelInitializer());
+        }
+        b.option(ChannelOption.SO_BACKLOG, 128);
+        b.childOption(ChannelOption.SO_KEEPALIVE, true);
 
         // Bind and start to accept incoming connections.
         b.bind(localEp.port()).sync().addListener(future -> {
@@ -283,7 +303,11 @@
             // http://normanmaurer.me/presentations/2014-facebook-eng-netty/slides.html#37.0
             bootstrap.channel(clientChannelClass);
             bootstrap.option(ChannelOption.SO_KEEPALIVE, true);
-            bootstrap.handler(new OnosCommunicationChannelInitializer());
+            if (enableNettyTLS) {
+                bootstrap.handler(new SSLClientCommunicationChannelInitializer());
+            } else {
+                bootstrap.handler(new OnosCommunicationChannelInitializer());
+            }
             // Start the client.
             ChannelFuture f = bootstrap.connect(ep.host().toString(), ep.port()).sync();
             log.debug("Established a new connection to {}", ep);
@@ -301,6 +325,77 @@
         }
     }
 
+    /**
+     * Builds an SSLContext from the JKS key store ({@code ksLocation}/{@code ksPwd}) and
+     * trust store ({@code tsLocation}/{@code tsPwd}). Invoked for every new channel, so the
+     * store files are re-read each time; the file streams are closed once loaded.
+     */
+    private SSLContext createSSLContext() throws Exception {
+        KeyStore ts = KeyStore.getInstance("JKS");
+        try (FileInputStream in = new FileInputStream(tsLocation)) {
+            ts.load(in, tsPwd);
+        }
+        TrustManagerFactory tmFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+        tmFactory.init(ts);
+
+        KeyStore ks = KeyStore.getInstance("JKS");
+        try (FileInputStream in = new FileInputStream(ksLocation)) {
+            ks.load(in, ksPwd);
+        }
+        KeyManagerFactory kmf = KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm());
+        kmf.init(ks, ksPwd);
+
+        SSLContext context = SSLContext.getInstance("TLS");
+        context.init(kmf.getKeyManagers(), tmFactory.getTrustManagers(), null);
+        return context;
+    }
+
+    /**
+     * Initializes accepted (server-side) channels with TLS; client certificates are required.
+     */
+    private class SSLServerCommunicationChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+        private final ChannelHandler dispatcher = new InboundMessageDispatcher();
+        private final ChannelHandler encoder = new MessageEncoder();
+
+        @Override
+        protected void initChannel(SocketChannel channel) throws Exception {
+            SSLEngine serverSSLEngine = createSSLContext().createSSLEngine();
+            serverSSLEngine.setUseClientMode(false);
+            serverSSLEngine.setNeedClientAuth(true);
+            // Keep the JDK's default enabled protocols and cipher suites: enabling every
+            // *supported* one would also turn on NULL/anonymous suites and legacy protocols.
+            serverSSLEngine.setEnableSessionCreation(true);
+
+            channel.pipeline().addLast("ssl", new io.netty.handler.ssl.SslHandler(serverSSLEngine))
+                    .addLast("encoder", encoder)
+                    .addLast("decoder", new MessageDecoder())
+                    .addLast("handler", dispatcher);
+        }
+    }
+
+    /**
+     * Initializes outgoing (client-side) channels with TLS, mirroring the server-side setup.
+     */
+    private class SSLClientCommunicationChannelInitializer extends ChannelInitializer<SocketChannel> {
+
+        private final ChannelHandler dispatcher = new InboundMessageDispatcher();
+        private final ChannelHandler encoder = new MessageEncoder();
+
+        @Override
+        protected void initChannel(SocketChannel channel) throws Exception {
+            SSLEngine clientSSLEngine = createSSLContext().createSSLEngine();
+            clientSSLEngine.setUseClientMode(true);
+            // Same protocol/cipher policy as the server side: JDK defaults are kept.
+            clientSSLEngine.setEnableSessionCreation(true);
+
+            channel.pipeline().addLast("ssl", new io.netty.handler.ssl.SslHandler(clientSSLEngine))
+                    .addLast("encoder", encoder)
+                    .addLast("decoder", new MessageDecoder())
+                    .addLast("handler", dispatcher);
+        }
+    }
+
     private class OnosCommunicationChannelInitializer extends ChannelInitializer<SocketChannel> {
 
         private final ChannelHandler dispatcher = new InboundMessageDispatcher();
@@ -308,10 +403,10 @@
 
         @Override
         protected void initChannel(SocketChannel channel) throws Exception {
-            channel.pipeline()
-                .addLast("encoder", encoder)
-                .addLast("decoder", new MessageDecoder())
-                .addLast("handler", dispatcher);
+                channel.pipeline()
+                        .addLast("encoder", encoder)
+                        .addLast("decoder", new MessageDecoder())
+                        .addLast("handler", dispatcher);
         }
     }