summaryrefslogtreecommitdiffhomepage
path: root/src/java/nginx/unit/websocket/WsWebSocketContainer.java
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/java/nginx/unit/websocket/WsWebSocketContainer.java1123
1 files changed, 1123 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/WsWebSocketContainer.java b/src/java/nginx/unit/websocket/WsWebSocketContainer.java
new file mode 100644
index 00000000..282665ef
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsWebSocketContainer.java
@@ -0,0 +1,1123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.InetSocketAddress;
+import java.net.Proxy;
+import java.net.ProxySelector;
+import java.net.SocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.nio.channels.AsynchronousChannelGroup;
+import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.charset.StandardCharsets;
+import java.security.KeyStore;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.TrustManagerFactory;
+import javax.websocket.ClientEndpoint;
+import javax.websocket.ClientEndpointConfig;
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.Endpoint;
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.Session;
+import javax.websocket.WebSocketContainer;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.InstanceManager;
+import org.apache.tomcat.util.buf.StringUtils;
+import org.apache.tomcat.util.codec.binary.Base64;
+import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.pojo.PojoEndpointClient;
+
+public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess {
+
+ private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class);
+ private static final Random RANDOM = new Random();
+ private static final byte[] CRLF = new byte[] { 13, 10 };
+
+ private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1);
+ private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1);
+ private static final byte[] HTTP_VERSION_BYTES =
+ " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1);
+
+ private volatile AsynchronousChannelGroup asynchronousChannelGroup = null;
+ private final Object asynchronousChannelGroupLock = new Object();
+
+ private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static
+ private final Map<Endpoint, Set<WsSession>> endpointSessionMap =
+ new HashMap<>();
+ private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>();
+ private final Object endPointSessionMapLock = new Object();
+
+ private long defaultAsyncTimeout = -1;
+ private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private volatile long defaultMaxSessionIdleTimeout = 0;
+ private int backgroundProcessCount = 0;
+ private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD;
+
+ private InstanceManager instanceManager;
+
+ InstanceManager getInstanceManager() {
+ return instanceManager;
+ }
+
+ protected void setInstanceManager(InstanceManager instanceManager) {
+ this.instanceManager = instanceManager;
+ }
+
+ @Override
+ public Session connectToServer(Object pojo, URI path)
+ throws DeploymentException {
+
+ ClientEndpoint annotation =
+ pojo.getClass().getAnnotation(ClientEndpoint.class);
+ if (annotation == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.missingAnnotation",
+ pojo.getClass().getName()));
+ }
+
+ Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders()));
+
+ Class<? extends ClientEndpointConfig.Configurator> configuratorClazz =
+ annotation.configurator();
+
+ ClientEndpointConfig.Configurator configurator = null;
+ if (!ClientEndpointConfig.Configurator.class.equals(
+ configuratorClazz)) {
+ try {
+ configurator = configuratorClazz.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.defaultConfiguratorFail"), e);
+ }
+ }
+
+ ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create();
+ // Avoid NPE when using RI API JAR - see BZ 56343
+ if (configurator != null) {
+ builder.configurator(configurator);
+ }
+ ClientEndpointConfig config = builder.
+ decoders(Arrays.asList(annotation.decoders())).
+ encoders(Arrays.asList(annotation.encoders())).
+ preferredSubprotocols(Arrays.asList(annotation.subprotocols())).
+ build();
+ return connectToServer(ep, config, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Class<?> annotatedEndpointClass, URI path)
+ throws DeploymentException {
+
+ Object pojo;
+ try {
+ pojo = annotatedEndpointClass.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.endpointCreateFail",
+ annotatedEndpointClass.getName()), e);
+ }
+
+ return connectToServer(pojo, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Class<? extends Endpoint> clazz,
+ ClientEndpointConfig clientEndpointConfiguration, URI path)
+ throws DeploymentException {
+
+ Endpoint endpoint;
+ try {
+ endpoint = clazz.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.endpointCreateFail", clazz.getName()),
+ e);
+ }
+
+ return connectToServer(endpoint, clientEndpointConfiguration, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Endpoint endpoint,
+ ClientEndpointConfig clientEndpointConfiguration, URI path)
+ throws DeploymentException {
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>());
+ }
+
+ private Session connectToServerRecursive(Endpoint endpoint,
+ ClientEndpointConfig clientEndpointConfiguration, URI path,
+ Set<URI> redirectSet)
+ throws DeploymentException {
+
+ boolean secure = false;
+ ByteBuffer proxyConnect = null;
+ URI proxyPath;
+
+ // Validate scheme (and build proxyPath)
+ String scheme = path.getScheme();
+ if ("ws".equalsIgnoreCase(scheme)) {
+ proxyPath = URI.create("http" + path.toString().substring(2));
+ } else if ("wss".equalsIgnoreCase(scheme)) {
+ proxyPath = URI.create("https" + path.toString().substring(3));
+ secure = true;
+ } else {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.pathWrongScheme", scheme));
+ }
+
+ // Validate host
+ String host = path.getHost();
+ if (host == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.pathNoHost"));
+ }
+ int port = path.getPort();
+
+ SocketAddress sa = null;
+
+ // Check to see if a proxy is configured. Javadoc indicates return value
+ // will never be null
+ List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath);
+ Proxy selectedProxy = null;
+ for (Proxy proxy : proxies) {
+ if (proxy.type().equals(Proxy.Type.HTTP)) {
+ sa = proxy.address();
+ if (sa instanceof InetSocketAddress) {
+ InetSocketAddress inet = (InetSocketAddress) sa;
+ if (inet.isUnresolved()) {
+ sa = new InetSocketAddress(inet.getHostName(), inet.getPort());
+ }
+ }
+ selectedProxy = proxy;
+ break;
+ }
+ }
+
+ // If the port is not explicitly specified, compute it based on the
+ // scheme
+ if (port == -1) {
+ if ("ws".equalsIgnoreCase(scheme)) {
+ port = 80;
+ } else {
+ // Must be wss due to scheme validation above
+ port = 443;
+ }
+ }
+
+ // If sa is null, no proxy is configured so need to create sa
+ if (sa == null) {
+ sa = new InetSocketAddress(host, port);
+ } else {
+ proxyConnect = createProxyRequest(host, port);
+ }
+
+ // Create the initial HTTP request to open the WebSocket connection
+ Map<String, List<String>> reqHeaders = createRequestHeaders(host, port,
+ clientEndpointConfiguration);
+ clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders);
+ if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null
+ && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) {
+ List<String> originValues = new ArrayList<>(1);
+ originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE);
+ reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues);
+ }
+ ByteBuffer request = createRequest(path, reqHeaders);
+
+ AsynchronousSocketChannel socketChannel;
+ try {
+ socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup());
+ } catch (IOException ioe) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe);
+ }
+
+ Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties();
+
+ // Get the connection timeout
+ long timeout = Constants.IO_TIMEOUT_MS_DEFAULT;
+ String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY);
+ if (timeoutValue != null) {
+ timeout = Long.valueOf(timeoutValue).intValue();
+ }
+
+ // Set-up
+ // Same size as the WsFrame input buffer
+ ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize());
+ String subProtocol;
+ boolean success = false;
+ List<Extension> extensionsAgreed = new ArrayList<>();
+ Transformation transformation = null;
+
+ // Open the connection
+ Future<Void> fConnect = socketChannel.connect(sa);
+ AsyncChannelWrapper channel = null;
+
+ if (proxyConnect != null) {
+ try {
+ fConnect.get(timeout, TimeUnit.MILLISECONDS);
+ // Proxy CONNECT is clear text
+ channel = new AsyncChannelWrapperNonSecure(socketChannel);
+ writeRequest(channel, proxyConnect, timeout);
+ HttpResponse httpResponse = processResponse(response, channel, timeout);
+ if (httpResponse.getStatus() != 200) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.proxyConnectFail", selectedProxy,
+ Integer.toString(httpResponse.getStatus())));
+ }
+ } catch (TimeoutException | InterruptedException | ExecutionException |
+ EOFException e) {
+ if (channel != null) {
+ channel.close();
+ }
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
+ }
+ }
+
+ if (secure) {
+ // Regardless of whether a non-secure wrapper was created for a
+ // proxy CONNECT, need to use TLS from this point on so wrap the
+ // original AsynchronousSocketChannel
+ SSLEngine sslEngine = createSSLEngine(userProperties, host, port);
+ channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
+ } else if (channel == null) {
+ // Only need to wrap as this point if it wasn't wrapped to process a
+ // proxy CONNECT
+ channel = new AsyncChannelWrapperNonSecure(socketChannel);
+ }
+
+ try {
+ fConnect.get(timeout, TimeUnit.MILLISECONDS);
+
+ Future<Void> fHandshake = channel.handshake();
+ fHandshake.get(timeout, TimeUnit.MILLISECONDS);
+
+ writeRequest(channel, request, timeout);
+
+ HttpResponse httpResponse = processResponse(response, channel, timeout);
+
+ // Check maximum permitted redirects
+ int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT;
+ String maxRedirectsValue =
+ (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY);
+ if (maxRedirectsValue != null) {
+ maxRedirects = Integer.parseInt(maxRedirectsValue);
+ }
+
+ if (httpResponse.status != 101) {
+ if(isRedirectStatus(httpResponse.status)){
+ List<String> locationHeader =
+ httpResponse.getHandshakeResponse().getHeaders().get(
+ Constants.LOCATION_HEADER_NAME);
+
+ if (locationHeader == null || locationHeader.isEmpty() ||
+ locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.missingLocationHeader",
+ Integer.toString(httpResponse.status)));
+ }
+
+ URI redirectLocation = URI.create(locationHeader.get(0)).normalize();
+
+ if (!redirectLocation.isAbsolute()) {
+ redirectLocation = path.resolve(redirectLocation);
+ }
+
+ String redirectScheme = redirectLocation.getScheme().toLowerCase();
+
+ if (redirectScheme.startsWith("http")) {
+ redirectLocation = new URI(redirectScheme.replace("http", "ws"),
+ redirectLocation.getUserInfo(), redirectLocation.getHost(),
+ redirectLocation.getPort(), redirectLocation.getPath(),
+ redirectLocation.getQuery(), redirectLocation.getFragment());
+ }
+
+ if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.redirectThreshold", redirectLocation,
+ Integer.toString(redirectSet.size()),
+ Integer.toString(maxRedirects)));
+ }
+
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet);
+
+ }
+
+ else if (httpResponse.status == 401) {
+
+ if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.failedAuthentication",
+ Integer.valueOf(httpResponse.status)));
+ }
+
+ List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse()
+ .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME);
+
+ if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() ||
+ wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.missingWWWAuthenticateHeader",
+ Integer.toString(httpResponse.status)));
+ }
+
+ String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0];
+ String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1)
+ .split("\\s", 3)[1];
+
+ Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme);
+
+ if (auth == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.unsupportedAuthScheme",
+ Integer.valueOf(httpResponse.status), authScheme));
+ }
+
+ userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization(
+ requestUri, wwwAuthenticateHeaders.get(0), userProperties));
+
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet);
+
+ }
+
+ else {
+ throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus",
+ Integer.toString(httpResponse.status)));
+ }
+ }
+ HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse();
+ clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse);
+
+ // Sub-protocol
+ List<String> protocolHeaders = handshakeResponse.getHeaders().get(
+ Constants.WS_PROTOCOL_HEADER_NAME);
+ if (protocolHeaders == null || protocolHeaders.size() == 0) {
+ subProtocol = null;
+ } else if (protocolHeaders.size() == 1) {
+ subProtocol = protocolHeaders.get(0);
+ } else {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.invalidSubProtocol"));
+ }
+
+ // Extensions
+ // Should normally only be one header but handle the case of
+ // multiple headers
+ List<String> extHeaders = handshakeResponse.getHeaders().get(
+ Constants.WS_EXTENSIONS_HEADER_NAME);
+ if (extHeaders != null) {
+ for (String extHeader : extHeaders) {
+ Util.parseExtensionHeader(extensionsAgreed, extHeader);
+ }
+ }
+
+ // Build the transformations
+ TransformationFactory factory = TransformationFactory.getInstance();
+ for (Extension extension : extensionsAgreed) {
+ List<List<Extension.Parameter>> wrapper = new ArrayList<>(1);
+ wrapper.add(extension.getParameters());
+ Transformation t = factory.create(extension.getName(), wrapper, false);
+ if (t == null) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidExtensionParameters"));
+ }
+ if (transformation == null) {
+ transformation = t;
+ } else {
+ transformation.setNext(t);
+ }
+ }
+
+ success = true;
+ } catch (ExecutionException | InterruptedException | SSLException |
+ EOFException | TimeoutException | URISyntaxException | AuthenticationException e) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
+ } finally {
+ if (!success) {
+ channel.close();
+ }
+ }
+
+ // Switch to WebSocket
+ WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel);
+
+ WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient,
+ this, null, null, null, null, null, extensionsAgreed,
+ subProtocol, Collections.<String,String>emptyMap(), secure,
+ clientEndpointConfiguration, null);
+
+ WsFrameClient wsFrameClient = new WsFrameClient(response, channel,
+ wsSession, transformation);
+ // WsFrame adds the necessary final transformations. Copy the
+ // completed transformation chain to the remote end point.
+ wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation());
+
+ endpoint.onOpen(wsSession, clientEndpointConfiguration);
+ registerSession(endpoint, wsSession);
+
+ /* It is possible that the server sent one or more messages as soon as
+ * the WebSocket connection was established. Depending on the exact
+ * timing of when those messages were sent they could be sat in the
+ * input buffer waiting to be read and will not trigger a "data
+ * available to read" event. Therefore, it is necessary to process the
+ * input buffer here. Note that this happens on the current thread which
+ * means that this thread will be used for any onMessage notifications.
+ * This is a special case. Subsequent "data available to read" events
+ * will be handled by threads from the AsyncChannelGroup's executor.
+ */
+ wsFrameClient.startInputProcessing();
+
+ return wsSession;
+ }
+
+
+ private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request,
+ long timeout) throws TimeoutException, InterruptedException, ExecutionException {
+ int toWrite = request.limit();
+
+ Future<Integer> fWrite = channel.write(request);
+ Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
+ toWrite -= thisWrite.intValue();
+
+ while (toWrite > 0) {
+ fWrite = channel.write(request);
+ thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
+ toWrite -= thisWrite.intValue();
+ }
+ }
+
+
+ private static boolean isRedirectStatus(int httpResponseCode) {
+
+ boolean isRedirect = false;
+
+ switch (httpResponseCode) {
+ case Constants.MULTIPLE_CHOICES:
+ case Constants.MOVED_PERMANENTLY:
+ case Constants.FOUND:
+ case Constants.SEE_OTHER:
+ case Constants.USE_PROXY:
+ case Constants.TEMPORARY_REDIRECT:
+ isRedirect = true;
+ break;
+ default:
+ break;
+ }
+
+ return isRedirect;
+ }
+
+
+ private static ByteBuffer createProxyRequest(String host, int port) {
+ StringBuilder request = new StringBuilder();
+ request.append("CONNECT ");
+ request.append(host);
+ request.append(':');
+ request.append(port);
+
+ request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: ");
+ request.append(host);
+ request.append(':');
+ request.append(port);
+
+ request.append("\r\n\r\n");
+
+ byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1);
+ return ByteBuffer.wrap(bytes);
+ }
+
+ protected void registerSession(Endpoint endpoint, WsSession wsSession) {
+
+ if (!wsSession.isOpen()) {
+ // The session was closed during onOpen. No need to register it.
+ return;
+ }
+ synchronized (endPointSessionMapLock) {
+ if (endpointSessionMap.size() == 0) {
+ BackgroundProcessManager.getInstance().register(this);
+ }
+ Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
+ if (wsSessions == null) {
+ wsSessions = new HashSet<>();
+ endpointSessionMap.put(endpoint, wsSessions);
+ }
+ wsSessions.add(wsSession);
+ }
+ sessions.put(wsSession, wsSession);
+ }
+
+
+ protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
+
+ synchronized (endPointSessionMapLock) {
+ Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
+ if (wsSessions != null) {
+ wsSessions.remove(wsSession);
+ if (wsSessions.size() == 0) {
+ endpointSessionMap.remove(endpoint);
+ }
+ }
+ if (endpointSessionMap.size() == 0) {
+ BackgroundProcessManager.getInstance().unregister(this);
+ }
+ }
+ sessions.remove(wsSession);
+ }
+
+
+ Set<Session> getOpenSessions(Endpoint endpoint) {
+ HashSet<Session> result = new HashSet<>();
+ synchronized (endPointSessionMapLock) {
+ Set<WsSession> sessions = endpointSessionMap.get(endpoint);
+ if (sessions != null) {
+ result.addAll(sessions);
+ }
+ }
+ return result;
+ }
+
+ private static Map<String, List<String>> createRequestHeaders(String host, int port,
+ ClientEndpointConfig clientEndpointConfiguration) {
+
+ Map<String, List<String>> headers = new HashMap<>();
+ List<Extension> extensions = clientEndpointConfiguration.getExtensions();
+ List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols();
+ Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties();
+
+ if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
+ List<String> authValues = new ArrayList<>(1);
+ authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME));
+ headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues);
+ }
+
+ // Host header
+ List<String> hostValues = new ArrayList<>(1);
+ if (port == -1) {
+ hostValues.add(host);
+ } else {
+ hostValues.add(host + ':' + port);
+ }
+
+ headers.put(Constants.HOST_HEADER_NAME, hostValues);
+
+ // Upgrade header
+ List<String> upgradeValues = new ArrayList<>(1);
+ upgradeValues.add(Constants.UPGRADE_HEADER_VALUE);
+ headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues);
+
+ // Connection header
+ List<String> connectionValues = new ArrayList<>(1);
+ connectionValues.add(Constants.CONNECTION_HEADER_VALUE);
+ headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues);
+
+ // WebSocket version header
+ List<String> wsVersionValues = new ArrayList<>(1);
+ wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE);
+ headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues);
+
+ // WebSocket key
+ List<String> wsKeyValues = new ArrayList<>(1);
+ wsKeyValues.add(generateWsKeyValue());
+ headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues);
+
+ // WebSocket sub-protocols
+ if (subProtocols != null && subProtocols.size() > 0) {
+ headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols);
+ }
+
+ // WebSocket extensions
+ if (extensions != null && extensions.size() > 0) {
+ headers.put(Constants.WS_EXTENSIONS_HEADER_NAME,
+ generateExtensionHeaders(extensions));
+ }
+
+ return headers;
+ }
+
+
+ private static List<String> generateExtensionHeaders(List<Extension> extensions) {
+ List<String> result = new ArrayList<>(extensions.size());
+ for (Extension extension : extensions) {
+ StringBuilder header = new StringBuilder();
+ header.append(extension.getName());
+ for (Extension.Parameter param : extension.getParameters()) {
+ header.append(';');
+ header.append(param.getName());
+ String value = param.getValue();
+ if (value != null && value.length() > 0) {
+ header.append('=');
+ header.append(value);
+ }
+ }
+ result.add(header.toString());
+ }
+ return result;
+ }
+
+
+ private static String generateWsKeyValue() {
+ byte[] keyBytes = new byte[16];
+ RANDOM.nextBytes(keyBytes);
+ return Base64.encodeBase64String(keyBytes);
+ }
+
+
+ private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) {
+ ByteBuffer result = ByteBuffer.allocate(4 * 1024);
+
+ // Request line
+ result.put(GET_BYTES);
+ if (null == uri.getPath() || "".equals(uri.getPath())) {
+ result.put(ROOT_URI_BYTES);
+ } else {
+ result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1));
+ }
+ String query = uri.getRawQuery();
+ if (query != null) {
+ result.put((byte) '?');
+ result.put(query.getBytes(StandardCharsets.ISO_8859_1));
+ }
+ result.put(HTTP_VERSION_BYTES);
+
+ // Headers
+ for (Entry<String, List<String>> entry : reqHeaders.entrySet()) {
+ result = addHeader(result, entry.getKey(), entry.getValue());
+ }
+
+ // Terminating CRLF
+ result.put(CRLF);
+
+ result.flip();
+
+ return result;
+ }
+
+
+ private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) {
+ if (values.isEmpty()) {
+ return result;
+ }
+
+ result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, CRLF);
+
+ return result;
+ }
+
+
+ private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) {
+ if (bytes.length > input.remaining()) {
+ int newSize;
+ if (bytes.length > input.capacity()) {
+ newSize = 2 * bytes.length;
+ } else {
+ newSize = input.capacity() * 2;
+ }
+ ByteBuffer expanded = ByteBuffer.allocate(newSize);
+ input.flip();
+ expanded.put(input);
+ input = expanded;
+ }
+ return input.put(bytes);
+ }
+
+
+ /**
+ * Process response, blocking until HTTP response has been fully received.
+ * @throws ExecutionException
+ * @throws InterruptedException
+ * @throws DeploymentException
+ * @throws TimeoutException
+ */
+ private HttpResponse processResponse(ByteBuffer response,
+ AsyncChannelWrapper channel, long timeout) throws InterruptedException,
+ ExecutionException, DeploymentException, EOFException,
+ TimeoutException {
+
+ Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();
+
+ int status = 0;
+ boolean readStatus = false;
+ boolean readHeaders = false;
+ String line = null;
+ while (!readHeaders) {
+ // On entering loop buffer will be empty and at the start of a new
+ // loop the buffer will have been fully read.
+ response.clear();
+ // Blocking read
+ Future<Integer> read = channel.read(response);
+ Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS);
+ if (bytesRead.intValue() == -1) {
+ throw new EOFException();
+ }
+ response.flip();
+ while (response.hasRemaining() && !readHeaders) {
+ if (line == null) {
+ line = readLine(response);
+ } else {
+ line += readLine(response);
+ }
+ if ("\r\n".equals(line)) {
+ readHeaders = true;
+ } else if (line.endsWith("\r\n")) {
+ if (readStatus) {
+ parseHeaders(line, headers);
+ } else {
+ status = parseStatus(line);
+ readStatus = true;
+ }
+ line = null;
+ }
+ }
+ }
+
+ return new HttpResponse(status, new WsHandshakeResponse(headers));
+ }
+
+
+ private int parseStatus(String line) throws DeploymentException {
+ // This client only understands HTTP 1.
+ // RFC2616 is case specific
+ String[] parts = line.trim().split(" ");
+ // CONNECT for proxy may return a 1.0 response
+ if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidStatus", line));
+ }
+ try {
+ return Integer.parseInt(parts[1]);
+ } catch (NumberFormatException nfe) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidStatus", line));
+ }
+ }
+
+
+ private void parseHeaders(String line, Map<String,List<String>> headers) {
+ // Treat headers as single values by default.
+
+ int index = line.indexOf(':');
+ if (index == -1) {
+ log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line));
+ return;
+ }
+ // Header names are case insensitive so always use lower case
+ String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH);
+ // Multi-value headers are stored as a single header and the client is
+ // expected to handle splitting into individual values
+ String headerValue = line.substring(index + 1).trim();
+
+ List<String> values = headers.get(headerName);
+ if (values == null) {
+ values = new ArrayList<>(1);
+ headers.put(headerName, values);
+ }
+ values.add(headerValue);
+ }
+
+ private String readLine(ByteBuffer response) {
+ // All ISO-8859-1
+ StringBuilder sb = new StringBuilder();
+
+ char c = 0;
+ while (response.hasRemaining()) {
+ c = (char) response.get();
+ sb.append(c);
+ if (c == 10) {
+ break;
+ }
+ }
+
+ return sb.toString();
+ }
+
+
+ private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port)
+ throws DeploymentException {
+
+ try {
+ // See if a custom SSLContext has been provided
+ SSLContext sslContext =
+ (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY);
+
+ if (sslContext == null) {
+ // Create the SSL Context
+ sslContext = SSLContext.getInstance("TLS");
+
+ // Trust store
+ String sslTrustStoreValue =
+ (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY);
+ if (sslTrustStoreValue != null) {
+ String sslTrustStorePwdValue = (String) userProperties.get(
+ Constants.SSL_TRUSTSTORE_PWD_PROPERTY);
+ if (sslTrustStorePwdValue == null) {
+ sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT;
+ }
+
+ File keyStoreFile = new File(sslTrustStoreValue);
+ KeyStore ks = KeyStore.getInstance("JKS");
+ try (InputStream is = new FileInputStream(keyStoreFile)) {
+ ks.load(is, sslTrustStorePwdValue.toCharArray());
+ }
+
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(
+ TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(ks);
+
+ sslContext.init(null, tmf.getTrustManagers(), null);
+ } else {
+ sslContext.init(null, null, null);
+ }
+ }
+
+ SSLEngine engine = sslContext.createSSLEngine(host, port);
+
+ String sslProtocolsValue =
+ (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY);
+ if (sslProtocolsValue != null) {
+ engine.setEnabledProtocols(sslProtocolsValue.split(","));
+ }
+
+ engine.setUseClientMode(true);
+
+ // Enable host verification
+ // Start with current settings (returns a copy)
+ SSLParameters sslParams = engine.getSSLParameters();
+ // Use HTTPS since WebSocket starts over HTTP(S)
+ sslParams.setEndpointIdentificationAlgorithm("HTTPS");
+ // Write the parameters back
+ engine.setSSLParameters(sslParams);
+
+ return engine;
+ } catch (Exception e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.sslEngineFail"), e);
+ }
+ }
+
+
+ @Override
+ public long getDefaultMaxSessionIdleTimeout() {
+ return defaultMaxSessionIdleTimeout;
+ }
+
+
+ @Override
+ public void setDefaultMaxSessionIdleTimeout(long timeout) {
+ this.defaultMaxSessionIdleTimeout = timeout;
+ }
+
+
+ @Override
+ public int getDefaultMaxBinaryMessageBufferSize() {
+ return maxBinaryMessageBufferSize;
+ }
+
+
+ @Override
+ public void setDefaultMaxBinaryMessageBufferSize(int max) {
+ maxBinaryMessageBufferSize = max;
+ }
+
+
+ @Override
+ public int getDefaultMaxTextMessageBufferSize() {
+ return maxTextMessageBufferSize;
+ }
+
+
+ @Override
+ public void setDefaultMaxTextMessageBufferSize(int max) {
+ maxTextMessageBufferSize = max;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * Currently, this implementation does not support any extensions.
+ */
+ @Override
+ public Set<Extension> getInstalledExtensions() {
+ return Collections.emptySet();
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value for this implementation is -1.
+ */
+ @Override
+ public long getDefaultAsyncSendTimeout() {
+ return defaultAsyncTimeout;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value for this implementation is -1.
+ */
+ @Override
+ public void setAsyncSendTimeout(long timeout) {
+ this.defaultAsyncTimeout = timeout;
+ }
+
+
+ /**
+ * Cleans up the resources still in use by WebSocket sessions created from
+ * this container. This includes closing sessions and cancelling
+ * {@link Future}s associated with blocking read/writes.
+ */
+ public void destroy() {
+ CloseReason cr = new CloseReason(
+ CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown"));
+
+ for (WsSession session : sessions.keySet()) {
+ try {
+ session.close(cr);
+ } catch (IOException ioe) {
+ log.debug(sm.getString(
+ "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe);
+ }
+ }
+
+ // Only unregister with AsyncChannelGroupUtil if this instance
+ // registered with it
+ if (asynchronousChannelGroup != null) {
+ synchronized (asynchronousChannelGroupLock) {
+ if (asynchronousChannelGroup != null) {
+ AsyncChannelGroupUtil.unregister();
+ asynchronousChannelGroup = null;
+ }
+ }
+ }
+ }
+
+
+ private AsynchronousChannelGroup getAsynchronousChannelGroup() {
+ // Use AsyncChannelGroupUtil to share a common group amongst all
+ // WebSocket clients
+ AsynchronousChannelGroup result = asynchronousChannelGroup;
+ if (result == null) {
+ synchronized (asynchronousChannelGroupLock) {
+ if (asynchronousChannelGroup == null) {
+ asynchronousChannelGroup = AsyncChannelGroupUtil.register();
+ }
+ result = asynchronousChannelGroup;
+ }
+ }
+ return result;
+ }
+
+
+ // ----------------------------------------------- BackgroundProcess methods
+
+ @Override
+ public void backgroundProcess() {
+ // This method gets called once a second.
+ backgroundProcessCount ++;
+ if (backgroundProcessCount >= processPeriod) {
+ backgroundProcessCount = 0;
+
+ for (WsSession wsSession : sessions.keySet()) {
+ wsSession.checkExpiration();
+ }
+ }
+
+ }
+
+
+ @Override
+ public void setProcessPeriod(int period) {
+ this.processPeriod = period;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value is 10 which means session expirations are processed
+ * every 10 seconds.
+ */
+ @Override
+ public int getProcessPeriod() {
+ return processPeriod;
+ }
+
+
+ private static class HttpResponse {
+ private final int status;
+ private final HandshakeResponse handshakeResponse;
+
+ public HttpResponse(int status, HandshakeResponse handshakeResponse) {
+ this.status = status;
+ this.handshakeResponse = handshakeResponse;
+ }
+
+
+ public int getStatus() {
+ return status;
+ }
+
+
+ public HandshakeResponse getHandshakeResponse() {
+ return handshakeResponse;
+ }
+ }
+}