| /* |
| * Copyright 2016-present Open Networking Laboratory |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.onosproject.bmv2.ctl; |
| |
| import com.google.common.collect.Maps; |
| import com.google.common.collect.Sets; |
| import org.apache.thrift.TProcessor; |
| import org.apache.thrift.server.TThreadedSelectorServer; |
| import org.apache.thrift.transport.TFramedTransport; |
| import org.apache.thrift.transport.TNonblockingServerSocket; |
| import org.apache.thrift.transport.TNonblockingServerTransport; |
| import org.apache.thrift.transport.TNonblockingSocket; |
| import org.apache.thrift.transport.TNonblockingTransport; |
| import org.apache.thrift.transport.TTransport; |
| import org.apache.thrift.transport.TTransportException; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.io.IOException; |
| import java.net.InetAddress; |
| import java.net.InetSocketAddress; |
| import java.nio.channels.SelectionKey; |
| import java.nio.channels.SocketChannel; |
| import java.util.Map; |
| import java.util.Set; |
| import java.util.concurrent.ExecutorService; |
| |
| /** |
| * A Thrift TThreadedSelectorServer that keeps track of the clients' IP address. |
| */ |
| final class Bmv2ControlPlaneThriftServer extends TThreadedSelectorServer { |
| |
| private static final int MAX_WORKER_THREADS = 20; |
| private static final int MAX_SELECTOR_THREADS = 4; |
| private static final int ACCEPT_QUEUE_LEN = 8; |
| |
| private final Map<TTransport, InetAddress> clientAddresses = Maps.newConcurrentMap(); |
| private final Set<TrackingSelectorThread> selectorThreads = Sets.newHashSet(); |
| |
| private AcceptThread acceptThread; |
| |
| private final Logger log = LoggerFactory.getLogger(this.getClass()); |
| |
| /** |
| * Creates a new server. |
| * |
| * @param port a listening port |
| * @param processor a processor |
| * @param executorService an executor service |
| * @throws TTransportException |
| */ |
| public Bmv2ControlPlaneThriftServer(int port, TProcessor processor, ExecutorService executorService) |
| throws TTransportException { |
| super(new TThreadedSelectorServer.Args(new TNonblockingServerSocket(port)) |
| .workerThreads(MAX_WORKER_THREADS) |
| .selectorThreads(MAX_SELECTOR_THREADS) |
| .acceptQueueSizePerThread(ACCEPT_QUEUE_LEN) |
| .executorService(executorService) |
| .processor(processor)); |
| } |
| |
| /** |
| * Returns the IP address of the client associated with the given input framed transport. |
| * |
| * @param inputTransport a framed transport instance |
| * @return the IP address of the client or null |
| */ |
| InetAddress getClientAddress(TFramedTransport inputTransport) { |
| return clientAddresses.get(inputTransport); |
| } |
| |
| @Override |
| protected boolean startThreads() { |
| try { |
| for (int i = 0; i < MAX_SELECTOR_THREADS; ++i) { |
| selectorThreads.add(new TrackingSelectorThread(ACCEPT_QUEUE_LEN)); |
| } |
| acceptThread = new AcceptThread((TNonblockingServerTransport) serverTransport_, |
| createSelectorThreadLoadBalancer(selectorThreads)); |
| selectorThreads.forEach(Thread::start); |
| acceptThread.start(); |
| return true; |
| } catch (IOException e) { |
| log.error("Failed to start threads!", e); |
| return false; |
| } |
| } |
| |
| @Override |
| protected void joinThreads() throws InterruptedException { |
| // Wait until the io threads exit. |
| acceptThread.join(); |
| for (TThreadedSelectorServer.SelectorThread thread : selectorThreads) { |
| thread.join(); |
| } |
| } |
| |
| @Override |
| public void stop() { |
| stopped_ = true; |
| // Stop queuing connect attempts asap. |
| stopListening(); |
| if (acceptThread != null) { |
| acceptThread.wakeupSelector(); |
| } |
| if (selectorThreads != null) { |
| selectorThreads.stream() |
| .filter(thread -> thread != null) |
| .forEach(TrackingSelectorThread::wakeupSelector); |
| } |
| } |
| |
| private class TrackingSelectorThread extends TThreadedSelectorServer.SelectorThread { |
| |
| TrackingSelectorThread(int maxPendingAccepts) throws IOException { |
| super(maxPendingAccepts); |
| } |
| |
| @Override |
| protected FrameBuffer createFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey, |
| AbstractSelectThread selectThread) { |
| TrackingFrameBuffer frameBuffer = new TrackingFrameBuffer(trans, selectionKey, selectThread); |
| if (trans instanceof TNonblockingSocket) { |
| try { |
| SocketChannel socketChannel = ((TNonblockingSocket) trans).getSocketChannel(); |
| InetAddress addr = ((InetSocketAddress) socketChannel.getRemoteAddress()).getAddress(); |
| clientAddresses.put(frameBuffer.getInputFramedTransport(), addr); |
| } catch (IOException e) { |
| log.warn("Exception while tracking client address", e); |
| clientAddresses.remove(frameBuffer.getInputFramedTransport()); |
| } |
| } else { |
| log.warn("Unknown TNonblockingTransport instance: {}", trans.getClass().getName()); |
| clientAddresses.remove(frameBuffer.getInputFramedTransport()); |
| } |
| return frameBuffer; |
| } |
| } |
| |
| private class TrackingFrameBuffer extends FrameBuffer { |
| |
| TrackingFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey, |
| AbstractSelectThread selectThread) { |
| super(trans, selectionKey, selectThread); |
| } |
| |
| TTransport getInputFramedTransport() { |
| return this.inTrans_; |
| } |
| } |
| } |