blob: bf444236da557e421698f45cdd168c190baf0ab9 [file] [log] [blame]
/*
* Copyright 2016-present Open Networking Laboratory
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.onosproject.bmv2.ctl;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import org.apache.thrift.TProcessor;
import org.apache.thrift.server.TThreadedSelectorServer;
import org.apache.thrift.transport.TFramedTransport;
import org.apache.thrift.transport.TNonblockingServerSocket;
import org.apache.thrift.transport.TNonblockingServerTransport;
import org.apache.thrift.transport.TNonblockingSocket;
import org.apache.thrift.transport.TNonblockingTransport;
import org.apache.thrift.transport.TTransport;
import org.apache.thrift.transport.TTransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ExecutorService;
/**
 * A Thrift {@link TThreadedSelectorServer} that keeps track of the IP address of each connected client.
 * <p>
 * When a connection is accepted, the client's remote IP address is recorded against that connection's
 * input framed transport. A request handler can then find out who sent a request via
 * {@link #getClientAddress(TFramedTransport)}. The record is dropped when the connection's frame buffer
 * is closed.
 */
final class Bmv2ControlPlaneThriftServer extends TThreadedSelectorServer {

    private static final int MAX_WORKER_THREADS = 20;
    private static final int MAX_SELECTOR_THREADS = 4;
    private static final int ACCEPT_QUEUE_LEN = 8;

    // Client IP address of each open connection, keyed by the connection's input framed transport.
    // Populated from the selector threads and read from other threads via getClientAddress(),
    // hence a concurrent map.
    private final Map<TTransport, InetAddress> clientAddresses = Maps.newConcurrentMap();
    private final Set<TrackingSelectorThread> selectorThreads = Sets.newHashSet();
    private AcceptThread acceptThread;

    private final Logger log = LoggerFactory.getLogger(this.getClass());

    /**
     * Creates a new server.
     *
     * @param port            a listening port
     * @param processor       the processor that handles incoming requests
     * @param executorService an executor service, handed to the Thrift server arguments
     * @throws TTransportException if the server socket cannot be created on the given port
     */
    public Bmv2ControlPlaneThriftServer(int port, TProcessor processor, ExecutorService executorService)
            throws TTransportException {
        super(new TThreadedSelectorServer.Args(new TNonblockingServerSocket(port))
                      .workerThreads(MAX_WORKER_THREADS)
                      .selectorThreads(MAX_SELECTOR_THREADS)
                      .acceptQueueSizePerThread(ACCEPT_QUEUE_LEN)
                      .executorService(executorService)
                      .processor(processor));
    }

    /**
     * Returns the IP address of the client associated with the given input framed transport.
     *
     * @param inputTransport a framed transport instance
     * @return the IP address of the client, or null if it is unknown (the address could not be
     *         determined, or the connection has already been closed)
     */
    InetAddress getClientAddress(TFramedTransport inputTransport) {
        return clientAddresses.get(inputTransport);
    }

    /**
     * Creates and starts the selector threads and the accept thread.
     * <p>
     * This override creates {@link TrackingSelectorThread} instances instead of plain selector threads,
     * so that every accepted connection gets a tracking frame buffer.
     *
     * @return true if all threads were started, false if an I/O error occurred
     */
    @Override
    protected boolean startThreads() {
        try {
            for (int i = 0; i < MAX_SELECTOR_THREADS; ++i) {
                selectorThreads.add(new TrackingSelectorThread(ACCEPT_QUEUE_LEN));
            }
            acceptThread = new AcceptThread((TNonblockingServerTransport) serverTransport_,
                                            createSelectorThreadLoadBalancer(selectorThreads));
            selectorThreads.forEach(Thread::start);
            acceptThread.start();
            return true;
        } catch (IOException e) {
            log.error("Failed to start threads!", e);
            return false;
        }
    }

    /**
     * Blocks until the accept thread and all selector threads have exited.
     *
     * @throws InterruptedException if interrupted while waiting
     */
    @Override
    protected void joinThreads() throws InterruptedException {
        // Wait until the io threads exit.
        acceptThread.join();
        for (TThreadedSelectorServer.SelectorThread thread : selectorThreads) {
            thread.join();
        }
    }

    /**
     * Signals the server to stop: stops listening and wakes up the accept and selector threads.
     */
    @Override
    public void stop() {
        stopped_ = true;
        // Stop queuing connect attempts asap.
        stopListening();
        if (acceptThread != null) {
            acceptThread.wakeupSelector();
        }
        // selectorThreads is final and only ever receives non-null threads, so no null checks are needed.
        selectorThreads.forEach(TrackingSelectorThread::wakeupSelector);
    }

    /**
     * Selector thread that creates {@link TrackingFrameBuffer}s and records the client IP address
     * of each new connection.
     */
    private class TrackingSelectorThread extends TThreadedSelectorServer.SelectorThread {

        TrackingSelectorThread(int maxPendingAccepts) throws IOException {
            super(maxPendingAccepts);
        }

        @Override
        protected FrameBuffer createFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey,
                                                AbstractSelectThread selectThread) {
            TrackingFrameBuffer frameBuffer = new TrackingFrameBuffer(trans, selectionKey, selectThread);
            if (!(trans instanceof TNonblockingSocket)) {
                log.warn("Unknown TNonblockingTransport instance: {}", trans.getClass().getName());
                return frameBuffer;
            }
            try {
                SocketChannel socketChannel = ((TNonblockingSocket) trans).getSocketChannel();
                java.net.SocketAddress remote = socketChannel.getRemoteAddress();
                // getRemoteAddress() may return null and getAddress() may return null; the concurrent
                // map rejects null values, so only record a fully resolved address.
                InetAddress addr = remote instanceof InetSocketAddress
                        ? ((InetSocketAddress) remote).getAddress() : null;
                if (addr != null) {
                    clientAddresses.put(frameBuffer.getInputFramedTransport(), addr);
                } else {
                    log.warn("Unable to determine client address from remote address {}", remote);
                }
            } catch (IOException e) {
                log.warn("Exception while tracking client address", e);
            }
            return frameBuffer;
        }
    }

    /**
     * Frame buffer that exposes its input framed transport (used as the key into the client address
     * map) and removes its client address entry when it is closed.
     */
    private class TrackingFrameBuffer extends FrameBuffer {

        TrackingFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey,
                            AbstractSelectThread selectThread) {
            super(trans, selectionKey, selectThread);
        }

        TTransport getInputFramedTransport() {
            return this.inTrans_;
        }

        @Override
        public void close() {
            try {
                super.close();
            } finally {
                // Without this, clientAddresses would keep one entry (and a reference to the
                // transport) for every connection ever accepted.
                clientAddresses.remove(getInputFramedTransport());
            }
        }
    }
}