blob: bf444236da557e421698f45cdd168c190baf0ab9 [file] [log] [blame]
Carmelo Cascone62f1e1e2016-06-22 01:43:49 -07001/*
2 * Copyright 2016-present Open Networking Laboratory
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17package org.onosproject.bmv2.ctl;
18
19import com.google.common.collect.Maps;
20import com.google.common.collect.Sets;
21import org.apache.thrift.TProcessor;
22import org.apache.thrift.server.TThreadedSelectorServer;
23import org.apache.thrift.transport.TFramedTransport;
24import org.apache.thrift.transport.TNonblockingServerSocket;
25import org.apache.thrift.transport.TNonblockingServerTransport;
26import org.apache.thrift.transport.TNonblockingSocket;
27import org.apache.thrift.transport.TNonblockingTransport;
28import org.apache.thrift.transport.TTransport;
29import org.apache.thrift.transport.TTransportException;
30import org.slf4j.Logger;
31import org.slf4j.LoggerFactory;
32
33import java.io.IOException;
34import java.net.InetAddress;
35import java.net.InetSocketAddress;
36import java.nio.channels.SelectionKey;
37import java.nio.channels.SocketChannel;
38import java.util.Map;
39import java.util.Set;
40import java.util.concurrent.ExecutorService;
41
42/**
43 * A Thrift TThreadedSelectorServer that keeps track of the clients' IP address.
44 */
45final class Bmv2ControlPlaneThriftServer extends TThreadedSelectorServer {
46
47 private static final int MAX_WORKER_THREADS = 20;
48 private static final int MAX_SELECTOR_THREADS = 4;
49 private static final int ACCEPT_QUEUE_LEN = 8;
50
51 private final Map<TTransport, InetAddress> clientAddresses = Maps.newConcurrentMap();
52 private final Set<TrackingSelectorThread> selectorThreads = Sets.newHashSet();
53
54 private AcceptThread acceptThread;
55
56 private final Logger log = LoggerFactory.getLogger(this.getClass());
57
58 /**
59 * Creates a new server.
60 *
61 * @param port a listening port
62 * @param processor a processor
63 * @param executorService an executor service
64 * @throws TTransportException
65 */
66 public Bmv2ControlPlaneThriftServer(int port, TProcessor processor, ExecutorService executorService)
67 throws TTransportException {
68 super(new TThreadedSelectorServer.Args(new TNonblockingServerSocket(port))
69 .workerThreads(MAX_WORKER_THREADS)
70 .selectorThreads(MAX_SELECTOR_THREADS)
71 .acceptQueueSizePerThread(ACCEPT_QUEUE_LEN)
72 .executorService(executorService)
73 .processor(processor));
74 }
75
76 /**
77 * Returns the IP address of the client associated with the given input framed transport.
78 *
79 * @param inputTransport a framed transport instance
80 * @return the IP address of the client or null
81 */
82 InetAddress getClientAddress(TFramedTransport inputTransport) {
83 return clientAddresses.get(inputTransport);
84 }
85
86 @Override
87 protected boolean startThreads() {
88 try {
89 for (int i = 0; i < MAX_SELECTOR_THREADS; ++i) {
90 selectorThreads.add(new TrackingSelectorThread(ACCEPT_QUEUE_LEN));
91 }
92 acceptThread = new AcceptThread((TNonblockingServerTransport) serverTransport_,
93 createSelectorThreadLoadBalancer(selectorThreads));
94 selectorThreads.forEach(Thread::start);
95 acceptThread.start();
96 return true;
97 } catch (IOException e) {
98 log.error("Failed to start threads!", e);
99 return false;
100 }
101 }
102
103 @Override
104 protected void joinThreads() throws InterruptedException {
105 // Wait until the io threads exit.
106 acceptThread.join();
107 for (TThreadedSelectorServer.SelectorThread thread : selectorThreads) {
108 thread.join();
109 }
110 }
111
112 @Override
113 public void stop() {
114 stopped_ = true;
115 // Stop queuing connect attempts asap.
116 stopListening();
117 if (acceptThread != null) {
118 acceptThread.wakeupSelector();
119 }
120 if (selectorThreads != null) {
121 selectorThreads.stream()
122 .filter(thread -> thread != null)
123 .forEach(TrackingSelectorThread::wakeupSelector);
124 }
125 }
126
127 private class TrackingSelectorThread extends TThreadedSelectorServer.SelectorThread {
128
129 TrackingSelectorThread(int maxPendingAccepts) throws IOException {
130 super(maxPendingAccepts);
131 }
132
133 @Override
134 protected FrameBuffer createFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey,
135 AbstractSelectThread selectThread) {
136 TrackingFrameBuffer frameBuffer = new TrackingFrameBuffer(trans, selectionKey, selectThread);
137 if (trans instanceof TNonblockingSocket) {
138 try {
139 SocketChannel socketChannel = ((TNonblockingSocket) trans).getSocketChannel();
140 InetAddress addr = ((InetSocketAddress) socketChannel.getRemoteAddress()).getAddress();
141 clientAddresses.put(frameBuffer.getInputFramedTransport(), addr);
142 } catch (IOException e) {
143 log.warn("Exception while tracking client address", e);
144 clientAddresses.remove(frameBuffer.getInputFramedTransport());
145 }
146 } else {
147 log.warn("Unknown TNonblockingTransport instance: {}", trans.getClass().getName());
148 clientAddresses.remove(frameBuffer.getInputFramedTransport());
149 }
150 return frameBuffer;
151 }
152 }
153
154 private class TrackingFrameBuffer extends FrameBuffer {
155
156 TrackingFrameBuffer(TNonblockingTransport trans, SelectionKey selectionKey,
157 AbstractSelectThread selectThread) {
158 super(trans, selectionKey, selectThread);
159 }
160
161 TTransport getInputFramedTransport() {
162 return this.inTrans_;
163 }
164 }
165}