summaryrefslogtreecommitdiffhomepage
path: root/src/java/nginx
diff options
context:
space:
mode:
authorMax Romanov <max.romanov@nginx.com>2019-09-05 15:27:32 +0300
committerMax Romanov <max.romanov@nginx.com>2019-09-05 15:27:32 +0300
commit2b8cab1e2478547398ad9c2fe68e025c180cac54 (patch)
treed317fcf9ee52f0f8967116f531784ae533b0ae5a /src/java/nginx
parent3e23afb0d205e503f6cc7d852e34d07da9a5b7f7 (diff)
downloadunit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.gz
unit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.bz2
Java: introducing websocket support.
Diffstat (limited to 'src/java/nginx')
-rw-r--r--src/java/nginx/unit/Context.java139
-rw-r--r--src/java/nginx/unit/Request.java90
-rw-r--r--src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java151
-rw-r--r--src/java/nginx/unit/websocket/AsyncChannelWrapper.java47
-rw-r--r--src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java112
-rw-r--r--src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java578
-rw-r--r--src/java/nginx/unit/websocket/AuthenticationException.java35
-rw-r--r--src/java/nginx/unit/websocket/Authenticator.java71
-rw-r--r--src/java/nginx/unit/websocket/AuthenticatorFactory.java68
-rw-r--r--src/java/nginx/unit/websocket/BackgroundProcess.java26
-rw-r--r--src/java/nginx/unit/websocket/BackgroundProcessManager.java149
-rw-r--r--src/java/nginx/unit/websocket/BasicAuthenticator.java66
-rw-r--r--src/java/nginx/unit/websocket/Constants.java158
-rw-r--r--src/java/nginx/unit/websocket/DecoderEntry.java39
-rw-r--r--src/java/nginx/unit/websocket/DigestAuthenticator.java150
-rw-r--r--src/java/nginx/unit/websocket/FutureToSendHandler.java112
-rw-r--r--src/java/nginx/unit/websocket/LocalStrings.properties147
-rw-r--r--src/java/nginx/unit/websocket/MessageHandlerResult.java42
-rw-r--r--src/java/nginx/unit/websocket/MessageHandlerResultType.java23
-rw-r--r--src/java/nginx/unit/websocket/MessagePart.java83
-rw-r--r--src/java/nginx/unit/websocket/PerMessageDeflate.java476
-rw-r--r--src/java/nginx/unit/websocket/ReadBufferOverflowException.java34
-rw-r--r--src/java/nginx/unit/websocket/Transformation.java111
-rw-r--r--src/java/nginx/unit/websocket/TransformationFactory.java51
-rw-r--r--src/java/nginx/unit/websocket/TransformationResult.java37
-rw-r--r--src/java/nginx/unit/websocket/Util.java666
-rw-r--r--src/java/nginx/unit/websocket/WrappedMessageHandler.java25
-rw-r--r--src/java/nginx/unit/websocket/WsContainerProvider.java28
-rw-r--r--src/java/nginx/unit/websocket/WsExtension.java46
-rw-r--r--src/java/nginx/unit/websocket/WsExtensionParameter.java40
-rw-r--r--src/java/nginx/unit/websocket/WsFrameBase.java1010
-rw-r--r--src/java/nginx/unit/websocket/WsFrameClient.java228
-rw-r--r--src/java/nginx/unit/websocket/WsHandshakeResponse.java56
-rw-r--r--src/java/nginx/unit/websocket/WsIOException.java41
-rw-r--r--src/java/nginx/unit/websocket/WsPongMessage.java39
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java79
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointBase.java64
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java76
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java1234
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java75
-rw-r--r--src/java/nginx/unit/websocket/WsSession.java1070
-rw-r--r--src/java/nginx/unit/websocket/WsWebSocketContainer.java1123
-rw-r--r--src/java/nginx/unit/websocket/pojo/Constants.java32
-rw-r--r--src/java/nginx/unit/websocket/pojo/LocalStrings.properties40
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java156
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java47
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java66
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java122
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java77
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java36
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java35
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java94
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java131
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java48
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java136
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java731
-rw-r--r--src/java/nginx/unit/websocket/pojo/PojoPathParam.java47
-rw-r--r--src/java/nginx/unit/websocket/pojo/package-info.java21
-rw-r--r--src/java/nginx/unit/websocket/server/Constants.java38
-rw-r--r--src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java88
-rw-r--r--src/java/nginx/unit/websocket/server/LocalStrings.properties43
-rw-r--r--src/java/nginx/unit/websocket/server/UpgradeUtil.java285
-rw-r--r--src/java/nginx/unit/websocket/server/UriTemplate.java177
-rw-r--r--src/java/nginx/unit/websocket/server/WsContextListener.java51
-rw-r--r--src/java/nginx/unit/websocket/server/WsFilter.java81
-rw-r--r--src/java/nginx/unit/websocket/server/WsHandshakeRequest.java196
-rw-r--r--src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java172
-rw-r--r--src/java/nginx/unit/websocket/server/WsMappingResult.java44
-rw-r--r--src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java84
-rw-r--r--src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java158
-rw-r--r--src/java/nginx/unit/websocket/server/WsSci.java145
-rw-r--r--src/java/nginx/unit/websocket/server/WsServerContainer.java470
-rw-r--r--src/java/nginx/unit/websocket/server/WsSessionListener.java36
-rw-r--r--src/java/nginx/unit/websocket/server/WsWriteTimeout.java128
-rw-r--r--src/java/nginx/unit/websocket/server/package-info.java21
75 files changed, 12875 insertions, 56 deletions
diff --git a/src/java/nginx/unit/Context.java b/src/java/nginx/unit/Context.java
index e1482903..6fcd6018 100644
--- a/src/java/nginx/unit/Context.java
+++ b/src/java/nginx/unit/Context.java
@@ -98,10 +98,14 @@ import javax.servlet.http.HttpSessionEvent;
import javax.servlet.http.HttpSessionIdListener;
import javax.servlet.http.HttpSessionListener;
+import javax.websocket.server.ServerEndpoint;
+
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import javax.xml.parsers.ParserConfigurationException;
+import nginx.unit.websocket.WsSession;
+
import org.eclipse.jetty.http.MimeTypes;
import org.w3c.dom.Document;
@@ -421,6 +425,9 @@ public class Context implements ServletContext, InitParams
loader_ = new AppClassLoader(urls,
Context.class.getClassLoader().getParent());
+ Class wsSession_class = WsSession.class;
+ trace("wsSession.test: " + WsSession.wsSession_test());
+
ClassLoader old = Thread.currentThread().getContextClassLoader();
Thread.currentThread().setContextClassLoader(loader_);
@@ -429,28 +436,30 @@ public class Context implements ServletContext, InitParams
addListener(listener_classname);
}
- ScanResult scan_res = null;
+ ClassGraph classgraph = new ClassGraph()
+ //.verbose()
+ .overrideClassLoaders(loader_)
+ .ignoreParentClassLoaders()
+ .enableClassInfo()
+ .enableAnnotationInfo()
+ //.enableSystemPackages()
+ .whitelistModules("javax.*")
+ //.enableAllInfo()
+ ;
- if (!metadata_complete_) {
- ClassGraph classgraph = new ClassGraph()
- //.verbose()
- .overrideClassLoaders(loader_)
- .ignoreParentClassLoaders()
- .enableClassInfo()
- .enableAnnotationInfo()
- //.enableSystemPackages()
- .whitelistModules("javax.*")
- //.enableAllInfo()
- ;
-
- String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim();
-
- if (verbose.equals("true")) {
- classgraph.verbose();
- }
+ String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim();
+
+ if (verbose.equals("true")) {
+ classgraph.verbose();
+ }
+
+ ScanResult scan_res = classgraph.scan();
+
+ javax.websocket.server.ServerEndpointConfig.Configurator.setDefault(new nginx.unit.websocket.server.DefaultServerEndpointConfigurator());
- scan_res = classgraph.scan();
+ loadInitializer(new nginx.unit.websocket.server.WsSci(), scan_res);
+ if (!metadata_complete_) {
loadInitializers(scan_res);
}
@@ -1471,54 +1480,61 @@ public class Context implements ServletContext, InitParams
ServiceLoader.load(ServletContainerInitializer.class, loader_);
for (ServletContainerInitializer sci : initializers) {
+ loadInitializer(sci, scan_res);
+ }
+ }
- trace("loadInitializers: initializer: " + sci.getClass().getName());
+ private void loadInitializer(ServletContainerInitializer sci, ScanResult scan_res)
+ {
+ trace("loadInitializer: initializer: " + sci.getClass().getName());
- HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class);
- if (ann == null) {
- trace("loadInitializers: no HandlesTypes annotation");
- continue;
- }
+ HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class);
+ if (ann == null) {
+ trace("loadInitializer: no HandlesTypes annotation");
+ return;
+ }
- Class<?>[] classes = ann.value();
- if (classes == null) {
- trace("loadInitializers: no handles classes");
- continue;
- }
+ Class<?>[] classes = ann.value();
+ if (classes == null) {
+ trace("loadInitializer: no handles classes");
+ return;
+ }
- Set<Class<?>> handles_classes = new HashSet<>();
+ Set<Class<?>> handles_classes = new HashSet<>();
- for (Class<?> c : classes) {
- trace("loadInitializers: find handles: " + c.getName());
+ for (Class<?> c : classes) {
+ trace("loadInitializer: find handles: " + c.getName());
- ClassInfoList handles = c.isInterface()
+ ClassInfoList handles =
+ c.isAnnotation()
+ ? scan_res.getClassesWithAnnotation(c.getName())
+ : c.isInterface()
? scan_res.getClassesImplementing(c.getName())
: scan_res.getSubclasses(c.getName());
- for (ClassInfo ci : handles) {
- if (ci.isInterface()
- || ci.isAnnotation()
- || ci.isAbstract())
- {
- continue;
- }
-
- trace("loadInitializers: handles class: " + ci.getName());
- handles_classes.add(ci.loadClass());
+ for (ClassInfo ci : handles) {
+ if (ci.isInterface()
+ || ci.isAnnotation()
+ || ci.isAbstract())
+ {
+ return;
}
- }
- if (handles_classes.isEmpty()) {
- trace("loadInitializers: no handles implementations");
- continue;
+ trace("loadInitializer: handles class: " + ci.getName());
+ handles_classes.add(ci.loadClass());
}
+ }
- try {
- sci.onStartup(handles_classes, this);
- metadata_complete_ = true;
- } catch(Exception e) {
- System.err.println("loadInitializers: exception caught: " + e.toString());
- }
+ if (handles_classes.isEmpty()) {
+ trace("loadInitializer: no handles implementations");
+ return;
+ }
+
+ try {
+ sci.onStartup(handles_classes, this);
+ metadata_complete_ = true;
+ } catch(Exception e) {
+ System.err.println("loadInitializer: exception caught: " + e.toString());
}
}
@@ -1691,6 +1707,21 @@ public class Context implements ServletContext, InitParams
listener_classnames_.add(ci.getName());
}
+
+
+ ClassInfoList endpoints = scan_res.getClassesWithAnnotation(ServerEndpoint.class.getName());
+
+ for (ClassInfo ci : endpoints) {
+ if (ci.isInterface()
+ || ci.isAnnotation()
+ || ci.isAbstract())
+ {
+ trace("scanClasses: skip server end point: " + ci.getName());
+ continue;
+ }
+
+ trace("scanClasses: server end point: " + ci.getName());
+ }
}
public void stop() throws IOException
diff --git a/src/java/nginx/unit/Request.java b/src/java/nginx/unit/Request.java
index 98584efe..335d7980 100644
--- a/src/java/nginx/unit/Request.java
+++ b/src/java/nginx/unit/Request.java
@@ -16,6 +16,7 @@ import java.lang.StringBuffer;
import java.net.URI;
import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
@@ -65,6 +66,9 @@ import org.eclipse.jetty.http.MultiPartFormInputStream;
import org.eclipse.jetty.http.HttpFields;
import org.eclipse.jetty.http.MimeTypes;
+import nginx.unit.websocket.WsSession;
+import nginx.unit.websocket.WsIOException;
+
public class Request implements HttpServletRequest, DynamicPathRequest
{
private final Context context;
@@ -114,6 +118,9 @@ public class Request implements HttpServletRequest, DynamicPathRequest
private boolean request_session_id_from_url = false;
private Session session = null;
+ private WsSession wsSession = null;
+ private boolean skip_close_ws = false;
+
private final ServletRequestAttributeListener attr_listener;
public static final String BARE = "nginx.unit.request.bare";
@@ -1203,11 +1210,30 @@ public class Request implements HttpServletRequest, DynamicPathRequest
public <T extends HttpUpgradeHandler> T upgrade(
Class<T> httpUpgradeHandlerClass) throws java.io.IOException, ServletException
{
- log("upgrade: " + httpUpgradeHandlerClass.getName());
+ trace("upgrade: " + httpUpgradeHandlerClass.getName());
- return null;
+ T handler;
+
+ try {
+ handler = httpUpgradeHandlerClass.getConstructor().newInstance();
+ } catch (Exception e) {
+ throw new ServletException(e);
+ }
+
+ upgrade(req_info_ptr);
+
+ return handler;
+ }
+
+ private static native void upgrade(long req_info_ptr);
+
+ public boolean isUpgrade()
+ {
+ return isUpgrade(req_info_ptr);
}
+ private static native boolean isUpgrade(long req_info_ptr);
+
@Override
public String changeSessionId()
{
@@ -1248,5 +1274,65 @@ public class Request implements HttpServletRequest, DynamicPathRequest
public static native void trace(long req_info_ptr, String msg, int msg_len);
private static native Response getResponse(long req_info_ptr);
+
+
+ public void setWsSession(WsSession s)
+ {
+ wsSession = s;
+ }
+
+ private void processWsFrame(ByteBuffer buf, byte opCode, boolean last)
+ throws IOException
+ {
+ trace("processWsFrame: " + opCode + ", [" + buf.position() + ", " + buf.limit() + "]");
+ try {
+ wsSession.processFrame(buf, opCode, last);
+ } catch (WsIOException e) {
+ wsSession.onClose(e.getCloseReason());
+ }
+ }
+
+ private void closeWsSession()
+ {
+ trace("closeWsSession");
+ skip_close_ws = true;
+
+ wsSession.onClose();
+ }
+
+ public void sendWsFrame(ByteBuffer payload, byte opCode, boolean last,
+ long timeoutExpiry) throws IOException
+ {
+ trace("sendWsFrame: " + opCode + ", [" + payload.position() +
+ ", " + payload.limit() + "]");
+
+ if (payload.isDirect()) {
+ sendWsFrame(req_info_ptr, payload, payload.position(),
+ payload.limit() - payload.position(), opCode, last);
+ } else {
+ sendWsFrame(req_info_ptr, payload.array(), payload.position(),
+ payload.limit() - payload.position(), opCode, last);
+ }
+ }
+
+ private static native void sendWsFrame(long req_info_ptr,
+ ByteBuffer buf, int pos, int len, byte opCode, boolean last);
+
+ private static native void sendWsFrame(long req_info_ptr,
+ byte[] arr, int pos, int len, byte opCode, boolean last);
+
+
+ public void closeWs()
+ {
+ if (skip_close_ws) {
+ return;
+ }
+
+ trace("closeWs");
+
+ closeWs(req_info_ptr);
+ }
+
+ private static native void closeWs(long req_info_ptr);
}
diff --git a/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java
new file mode 100644
index 00000000..147112c1
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.channels.AsynchronousChannelGroup;
+import java.security.AccessController;
+import java.security.PrivilegedAction;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.apache.tomcat.util.res.StringManager;
+import org.apache.tomcat.util.threads.ThreadPoolExecutor;
+
+/**
+ * This is a utility class that enables multiple {@link WsWebSocketContainer}
+ * instances to share a single {@link AsynchronousChannelGroup} while ensuring
+ * that the group is destroyed when no longer required.
+ */
+public class AsyncChannelGroupUtil {
+
+ private static final StringManager sm =
+ StringManager.getManager(AsyncChannelGroupUtil.class);
+
+ private static AsynchronousChannelGroup group = null;
+ private static int usageCount = 0;
+ private static final Object lock = new Object();
+
+
+ private AsyncChannelGroupUtil() {
+ // Hide the default constructor
+ }
+
+
+ public static AsynchronousChannelGroup register() {
+ synchronized (lock) {
+ if (usageCount == 0) {
+ group = createAsynchronousChannelGroup();
+ }
+ usageCount++;
+ return group;
+ }
+ }
+
+
+ public static void unregister() {
+ synchronized (lock) {
+ usageCount--;
+ if (usageCount == 0) {
+ group.shutdown();
+ group = null;
+ }
+ }
+ }
+
+
+ private static AsynchronousChannelGroup createAsynchronousChannelGroup() {
+ // Need to do this with the right thread context class loader else the
+ // first web app to call this will trigger a leak
+ ClassLoader original = Thread.currentThread().getContextClassLoader();
+
+ try {
+ Thread.currentThread().setContextClassLoader(
+ AsyncIOThreadFactory.class.getClassLoader());
+
+ // These are the same settings as the default
+ // AsynchronousChannelGroup
+ int initialSize = Runtime.getRuntime().availableProcessors();
+ ExecutorService executorService = new ThreadPoolExecutor(
+ 0,
+ Integer.MAX_VALUE,
+ Long.MAX_VALUE, TimeUnit.MILLISECONDS,
+ new SynchronousQueue<Runnable>(),
+ new AsyncIOThreadFactory());
+
+ try {
+ return AsynchronousChannelGroup.withCachedThreadPool(
+ executorService, initialSize);
+ } catch (IOException e) {
+ // No good reason for this to happen.
+ throw new IllegalStateException(sm.getString("asyncChannelGroup.createFail"));
+ }
+ } finally {
+ Thread.currentThread().setContextClassLoader(original);
+ }
+ }
+
+
+ private static class AsyncIOThreadFactory implements ThreadFactory {
+
+ static {
+ // Load NewThreadPrivilegedAction since newThread() will not be able
+ // to if called from an InnocuousThread.
+ // See https://bz.apache.org/bugzilla/show_bug.cgi?id=57490
+ NewThreadPrivilegedAction.load();
+ }
+
+
+ @Override
+ public Thread newThread(final Runnable r) {
+ // Create the new Thread within a doPrivileged block to ensure that
+ // the thread inherits the current ProtectionDomain which is
+ // essential to be able to use this with a Java Applet. See
+ // https://bz.apache.org/bugzilla/show_bug.cgi?id=57091
+ return AccessController.doPrivileged(new NewThreadPrivilegedAction(r));
+ }
+
+ // Non-anonymous class so that AsyncIOThreadFactory can load it
+ // explicitly
+ private static class NewThreadPrivilegedAction implements PrivilegedAction<Thread> {
+
+ private static AtomicInteger count = new AtomicInteger(0);
+
+ private final Runnable r;
+
+ public NewThreadPrivilegedAction(Runnable r) {
+ this.r = r;
+ }
+
+ @Override
+ public Thread run() {
+ Thread t = new Thread(r);
+ t.setName("WebSocketClient-AsyncIO-" + count.incrementAndGet());
+ t.setContextClassLoader(this.getClass().getClassLoader());
+ t.setDaemon(true);
+ return t;
+ }
+
+ private static void load() {
+ // NO-OP. Just provides a hook to enable the class to be loaded
+ }
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapper.java b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java
new file mode 100644
index 00000000..060ae9cb
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.ByteBuffer;
+import java.nio.channels.CompletionHandler;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+
+import javax.net.ssl.SSLException;
+
+/**
+ * This is a wrapper for a {@link java.nio.channels.AsynchronousSocketChannel}
+ * that limits the methods available thereby simplifying the process of
+ * implementing SSL/TLS support since there are fewer methods to intercept.
+ */
+public interface AsyncChannelWrapper {
+
+ Future<Integer> read(ByteBuffer dst);
+
+ <B,A extends B> void read(ByteBuffer dst, A attachment,
+ CompletionHandler<Integer,B> handler);
+
+ Future<Integer> write(ByteBuffer src);
+
+ <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
+ long timeout, TimeUnit unit, A attachment,
+ CompletionHandler<Long,B> handler);
+
+ void close();
+
+ Future<Void> handshake() throws SSLException;
+}
diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java
new file mode 100644
index 00000000..5b88bfe1
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.channels.CompletionHandler;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+/**
+ * Generally, just passes calls straight to the wrapped
+ * {@link AsynchronousSocketChannel}. In some cases exceptions may be swallowed
+ * to save them being swallowed by the calling code.
+ */
+public class AsyncChannelWrapperNonSecure implements AsyncChannelWrapper {
+
+ private static final Future<Void> NOOP_FUTURE = new NoOpFuture();
+
+ private final AsynchronousSocketChannel socketChannel;
+
+ public AsyncChannelWrapperNonSecure(
+ AsynchronousSocketChannel socketChannel) {
+ this.socketChannel = socketChannel;
+ }
+
+ @Override
+ public Future<Integer> read(ByteBuffer dst) {
+ return socketChannel.read(dst);
+ }
+
+ @Override
+ public <B,A extends B> void read(ByteBuffer dst, A attachment,
+ CompletionHandler<Integer,B> handler) {
+ socketChannel.read(dst, attachment, handler);
+ }
+
+ @Override
+ public Future<Integer> write(ByteBuffer src) {
+ return socketChannel.write(src);
+ }
+
+ @Override
+ public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
+ long timeout, TimeUnit unit, A attachment,
+ CompletionHandler<Long,B> handler) {
+ socketChannel.write(
+ srcs, offset, length, timeout, unit, attachment, handler);
+ }
+
+ @Override
+ public void close() {
+ try {
+ socketChannel.close();
+ } catch (IOException e) {
+ // Ignore
+ }
+ }
+
+ @Override
+ public Future<Void> handshake() {
+ return NOOP_FUTURE;
+ }
+
+
+ private static final class NoOpFuture implements Future<Void> {
+
+ @Override
+ public boolean cancel(boolean mayInterruptIfRunning) {
+ return false;
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return false;
+ }
+
+ @Override
+ public boolean isDone() {
+ return true;
+ }
+
+ @Override
+ public Void get() throws InterruptedException, ExecutionException {
+ return null;
+ }
+
+ @Override
+ public Void get(long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException,
+ TimeoutException {
+ return null;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java
new file mode 100644
index 00000000..21654487
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java
@@ -0,0 +1,578 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.channels.CompletionHandler;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.ThreadFactory;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLEngineResult;
+import javax.net.ssl.SSLEngineResult.HandshakeStatus;
+import javax.net.ssl.SSLEngineResult.Status;
+import javax.net.ssl.SSLException;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot
+ * more testing before it can be considered robust.
+ */
+public class AsyncChannelWrapperSecure implements AsyncChannelWrapper {
+
+ private final Log log =
+ LogFactory.getLog(AsyncChannelWrapperSecure.class);
+ private static final StringManager sm =
+ StringManager.getManager(AsyncChannelWrapperSecure.class);
+
+ private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921);
+ private final AsynchronousSocketChannel socketChannel;
+ private final SSLEngine sslEngine;
+ private final ByteBuffer socketReadBuffer;
+ private final ByteBuffer socketWriteBuffer;
+ // One thread for read, one for write
+ private final ExecutorService executor =
+ Executors.newFixedThreadPool(2, new SecureIOThreadFactory());
+ private AtomicBoolean writing = new AtomicBoolean(false);
+ private AtomicBoolean reading = new AtomicBoolean(false);
+
+ public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel,
+ SSLEngine sslEngine) {
+ this.socketChannel = socketChannel;
+ this.sslEngine = sslEngine;
+
+ int socketBufferSize = sslEngine.getSession().getPacketBufferSize();
+ socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize);
+ socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize);
+ }
+
+ @Override
+ public Future<Integer> read(ByteBuffer dst) {
+ WrapperFuture<Integer,Void> future = new WrapperFuture<>();
+
+ if (!reading.compareAndSet(false, true)) {
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.concurrentRead"));
+ }
+
+ ReadTask readTask = new ReadTask(dst, future);
+
+ executor.execute(readTask);
+
+ return future;
+ }
+
+ @Override
+ public <B,A extends B> void read(ByteBuffer dst, A attachment,
+ CompletionHandler<Integer,B> handler) {
+
+ WrapperFuture<Integer,B> future =
+ new WrapperFuture<>(handler, attachment);
+
+ if (!reading.compareAndSet(false, true)) {
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.concurrentRead"));
+ }
+
+ ReadTask readTask = new ReadTask(dst, future);
+
+ executor.execute(readTask);
+ }
+
+ @Override
+ public Future<Integer> write(ByteBuffer src) {
+
+ WrapperFuture<Long,Void> inner = new WrapperFuture<>();
+
+ if (!writing.compareAndSet(false, true)) {
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.concurrentWrite"));
+ }
+
+ WriteTask writeTask =
+ new WriteTask(new ByteBuffer[] {src}, 0, 1, inner);
+
+ executor.execute(writeTask);
+
+ Future<Integer> future = new LongToIntegerFuture(inner);
+ return future;
+ }
+
+ @Override
+ public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length,
+ long timeout, TimeUnit unit, A attachment,
+ CompletionHandler<Long,B> handler) {
+
+ WrapperFuture<Long,B> future =
+ new WrapperFuture<>(handler, attachment);
+
+ if (!writing.compareAndSet(false, true)) {
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.concurrentWrite"));
+ }
+
+ WriteTask writeTask = new WriteTask(srcs, offset, length, future);
+
+ executor.execute(writeTask);
+ }
+
+ @Override
+ public void close() {
+ try {
+ socketChannel.close();
+ } catch (IOException e) {
+ log.info(sm.getString("asyncChannelWrapperSecure.closeFail"));
+ }
+ executor.shutdownNow();
+ }
+
+ @Override
+ public Future<Void> handshake() throws SSLException {
+
+ WrapperFuture<Void,Void> wFuture = new WrapperFuture<>();
+
+ Thread t = new WebSocketSslHandshakeThread(wFuture);
+ t.start();
+
+ return wFuture;
+ }
+
+
+ private class WriteTask implements Runnable {
+
+ private final ByteBuffer[] srcs;
+ private final int offset;
+ private final int length;
+ private final WrapperFuture<Long,?> future;
+
+ public WriteTask(ByteBuffer[] srcs, int offset, int length,
+ WrapperFuture<Long,?> future) {
+ this.srcs = srcs;
+ this.future = future;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public void run() {
+ long written = 0;
+
+ try {
+ for (int i = offset; i < offset + length; i++) {
+ ByteBuffer src = srcs[i];
+ while (src.hasRemaining()) {
+ socketWriteBuffer.clear();
+
+ // Encrypt the data
+ SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer);
+ written += r.bytesConsumed();
+ Status s = r.getStatus();
+
+ if (s == Status.OK || s == Status.BUFFER_OVERFLOW) {
+ // Need to write out the bytes and may need to read from
+ // the source again to empty it
+ } else {
+ // Status.BUFFER_UNDERFLOW - only happens on unwrap
+ // Status.CLOSED - unexpected
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.statusWrap"));
+ }
+
+ // Check for tasks
+ if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
+ Runnable runnable = sslEngine.getDelegatedTask();
+ while (runnable != null) {
+ runnable.run();
+ runnable = sslEngine.getDelegatedTask();
+ }
+ }
+
+ socketWriteBuffer.flip();
+
+ // Do the write
+ int toWrite = r.bytesProduced();
+ while (toWrite > 0) {
+ Future<Integer> f =
+ socketChannel.write(socketWriteBuffer);
+ Integer socketWrite = f.get();
+ toWrite -= socketWrite.intValue();
+ }
+ }
+ }
+
+
+ if (writing.compareAndSet(true, false)) {
+ future.complete(Long.valueOf(written));
+ } else {
+ future.fail(new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.wrongStateWrite")));
+ }
+ } catch (Exception e) {
+ writing.set(false);
+ future.fail(e);
+ }
+ }
+ }
+
+
+ private class ReadTask implements Runnable {
+
+ private final ByteBuffer dest;
+ private final WrapperFuture<Integer,?> future;
+
+ public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) {
+ this.dest = dest;
+ this.future = future;
+ }
+
+ @Override
+ public void run() {
+ int read = 0;
+
+ boolean forceRead = false;
+
+ try {
+ while (read == 0) {
+ socketReadBuffer.compact();
+
+ if (forceRead) {
+ forceRead = false;
+ Future<Integer> f = socketChannel.read(socketReadBuffer);
+ Integer socketRead = f.get();
+ if (socketRead.intValue() == -1) {
+ throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof"));
+ }
+ }
+
+ socketReadBuffer.flip();
+
+ if (socketReadBuffer.hasRemaining()) {
+ // Decrypt the data in the buffer
+ SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest);
+ read += r.bytesProduced();
+ Status s = r.getStatus();
+
+ if (s == Status.OK) {
+ // Bytes available for reading and there may be
+ // sufficient data in the socketReadBuffer to
+ // support further reads without reading from the
+ // socket
+ } else if (s == Status.BUFFER_UNDERFLOW) {
+ // There is partial data in the socketReadBuffer
+ if (read == 0) {
+ // Need more data before the partial data can be
+ // processed and some output generated
+ forceRead = true;
+ }
+ // else return the data we have and deal with the
+ // partial data on the next read
+ } else if (s == Status.BUFFER_OVERFLOW) {
+ // Not enough space in the destination buffer to
+ // store all of the data. We could use a bytes read
+ // value of -bufferSizeRequired to signal the new
+ // buffer size required but an explicit exception is
+ // clearer.
+ if (reading.compareAndSet(true, false)) {
+ throw new ReadBufferOverflowException(sslEngine.
+ getSession().getApplicationBufferSize());
+ } else {
+ future.fail(new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.wrongStateRead")));
+ }
+ } else {
+ // Status.CLOSED - unexpected
+ throw new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.statusUnwrap"));
+ }
+
+ // Check for tasks
+ if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
+ Runnable runnable = sslEngine.getDelegatedTask();
+ while (runnable != null) {
+ runnable.run();
+ runnable = sslEngine.getDelegatedTask();
+ }
+ }
+ } else {
+ forceRead = true;
+ }
+ }
+
+
+ if (reading.compareAndSet(true, false)) {
+ future.complete(Integer.valueOf(read));
+ } else {
+ future.fail(new IllegalStateException(sm.getString(
+ "asyncChannelWrapperSecure.wrongStateRead")));
+ }
+ } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException |
+ ExecutionException | InterruptedException e) {
+ reading.set(false);
+ future.fail(e);
+ }
+ }
+ }
+
+
+ private class WebSocketSslHandshakeThread extends Thread {
+
+ private final WrapperFuture<Void,Void> hFuture;
+
+ private HandshakeStatus handshakeStatus;
+ private Status resultStatus;
+
+ public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) {
+ this.hFuture = hFuture;
+ }
+
+ @Override
+ public void run() {
+ try {
+ sslEngine.beginHandshake();
+ // So the first compact does the right thing
+ socketReadBuffer.position(socketReadBuffer.limit());
+
+ handshakeStatus = sslEngine.getHandshakeStatus();
+ resultStatus = Status.OK;
+
+ boolean handshaking = true;
+
+ while(handshaking) {
+ switch (handshakeStatus) {
+ case NEED_WRAP: {
+ socketWriteBuffer.clear();
+ SSLEngineResult r =
+ sslEngine.wrap(DUMMY, socketWriteBuffer);
+ checkResult(r, true);
+ socketWriteBuffer.flip();
+ Future<Integer> fWrite =
+ socketChannel.write(socketWriteBuffer);
+ fWrite.get();
+ break;
+ }
+ case NEED_UNWRAP: {
+ socketReadBuffer.compact();
+ if (socketReadBuffer.position() == 0 ||
+ resultStatus == Status.BUFFER_UNDERFLOW) {
+ Future<Integer> fRead =
+ socketChannel.read(socketReadBuffer);
+ fRead.get();
+ }
+ socketReadBuffer.flip();
+ SSLEngineResult r =
+ sslEngine.unwrap(socketReadBuffer, DUMMY);
+ checkResult(r, false);
+ break;
+ }
+ case NEED_TASK: {
+ Runnable r = null;
+ while ((r = sslEngine.getDelegatedTask()) != null) {
+ r.run();
+ }
+ handshakeStatus = sslEngine.getHandshakeStatus();
+ break;
+ }
+ case FINISHED: {
+ handshaking = false;
+ break;
+ }
+ case NOT_HANDSHAKING: {
+ throw new SSLException(
+ sm.getString("asyncChannelWrapperSecure.notHandshaking"));
+ }
+ }
+ }
+ } catch (Exception e) {
+ hFuture.fail(e);
+ return;
+ }
+
+ hFuture.complete(null);
+ }
+
+ private void checkResult(SSLEngineResult result, boolean wrap)
+ throws SSLException {
+
+ handshakeStatus = result.getHandshakeStatus();
+ resultStatus = result.getStatus();
+
+ if (resultStatus != Status.OK &&
+ (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) {
+ throw new SSLException(
+ sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus));
+ }
+ if (wrap && result.bytesConsumed() != 0) {
+ throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap"));
+ }
+ if (!wrap && result.bytesProduced() != 0) {
+ throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap"));
+ }
+ }
+ }
+
+
+ private static class WrapperFuture<T,A> implements Future<T> {
+
+ private final CompletionHandler<T,A> handler;
+ private final A attachment;
+
+ private volatile T result = null;
+ private volatile Throwable throwable = null;
+ private CountDownLatch completionLatch = new CountDownLatch(1);
+
+ public WrapperFuture() {
+ this(null, null);
+ }
+
+ public WrapperFuture(CompletionHandler<T,A> handler, A attachment) {
+ this.handler = handler;
+ this.attachment = attachment;
+ }
+
+ public void complete(T result) {
+ this.result = result;
+ completionLatch.countDown();
+ if (handler != null) {
+ handler.completed(result, attachment);
+ }
+ }
+
+ public void fail(Throwable t) {
+ throwable = t;
+ completionLatch.countDown();
+ if (handler != null) {
+ handler.failed(throwable, attachment);
+ }
+ }
+
+ @Override
+ public final boolean cancel(boolean mayInterruptIfRunning) {
+ // Could support cancellation by closing the connection
+ return false;
+ }
+
+ @Override
+ public final boolean isCancelled() {
+ // Could support cancellation by closing the connection
+ return false;
+ }
+
+ @Override
+ public final boolean isDone() {
+ return completionLatch.getCount() > 0;
+ }
+
+ @Override
+ public T get() throws InterruptedException, ExecutionException {
+ completionLatch.await();
+ if (throwable != null) {
+ throw new ExecutionException(throwable);
+ }
+ return result;
+ }
+
+ @Override
+ public T get(long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException,
+ TimeoutException {
+ boolean latchResult = completionLatch.await(timeout, unit);
+ if (latchResult == false) {
+ throw new TimeoutException();
+ }
+ if (throwable != null) {
+ throw new ExecutionException(throwable);
+ }
+ return result;
+ }
+ }
+
+ private static final class LongToIntegerFuture implements Future<Integer> {
+
+ private final Future<Long> wrapped;
+
+ public LongToIntegerFuture(Future<Long> wrapped) {
+ this.wrapped = wrapped;
+ }
+
+ @Override
+ public boolean cancel(boolean mayInterruptIfRunning) {
+ return wrapped.cancel(mayInterruptIfRunning);
+ }
+
+ @Override
+ public boolean isCancelled() {
+ return wrapped.isCancelled();
+ }
+
+ @Override
+ public boolean isDone() {
+ return wrapped.isDone();
+ }
+
+ @Override
+ public Integer get() throws InterruptedException, ExecutionException {
+ Long result = wrapped.get();
+ if (result.longValue() > Integer.MAX_VALUE) {
+ throw new ExecutionException(sm.getString(
+ "asyncChannelWrapperSecure.tooBig", result), null);
+ }
+ return Integer.valueOf(result.intValue());
+ }
+
+ @Override
+ public Integer get(long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException,
+ TimeoutException {
+ Long result = wrapped.get(timeout, unit);
+ if (result.longValue() > Integer.MAX_VALUE) {
+ throw new ExecutionException(sm.getString(
+ "asyncChannelWrapperSecure.tooBig", result), null);
+ }
+ return Integer.valueOf(result.intValue());
+ }
+ }
+
+
+ private static class SecureIOThreadFactory implements ThreadFactory {
+
+ private AtomicInteger count = new AtomicInteger(0);
+
+ @Override
+ public Thread newThread(Runnable r) {
+ Thread t = new Thread(r);
+ t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet());
+ // No need to set the context class loader. The threads will be
+ // cleaned up when the connection is closed.
+ t.setDaemon(true);
+ return t;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/AuthenticationException.java b/src/java/nginx/unit/websocket/AuthenticationException.java
new file mode 100644
index 00000000..001f1829
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AuthenticationException.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+/**
+ * Exception thrown on authentication error connecting to a remote
+ * websocket endpoint.
+ */
+public class AuthenticationException extends Exception {
+
+ private static final long serialVersionUID = 5709887412240096441L;
+
+ /**
+ * Create authentication exception.
+ * @param message the error message
+ */
+ public AuthenticationException(String message) {
+ super(message);
+ }
+
+}
diff --git a/src/java/nginx/unit/websocket/Authenticator.java b/src/java/nginx/unit/websocket/Authenticator.java
new file mode 100644
index 00000000..87b3ce6d
--- /dev/null
+++ b/src/java/nginx/unit/websocket/Authenticator.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * Base class for the authentication methods used by the websocket client.
+ */
+public abstract class Authenticator {
+ private static final Pattern pattern = Pattern
+ .compile("(\\w+)\\s*=\\s*(\"([^\"]+)\"|([^,=\"]+))\\s*,?");
+
+ /**
+ * Generate the authentication header that will be sent to the server.
+ * @param requestUri The request URI
+ * @param WWWAuthenticate The server auth challenge
+ * @param UserProperties The user information
+ * @return The auth header
+ * @throws AuthenticationException When an error occurs
+ */
+ public abstract String getAuthorization(String requestUri, String WWWAuthenticate,
+ Map<String, Object> UserProperties) throws AuthenticationException;
+
+ /**
+ * Get the authentication method.
+ * @return the auth scheme
+ */
+ public abstract String getSchemeName();
+
+ /**
+ * Utility method to parse the authentication header.
+ * @param WWWAuthenticate The server auth challenge
+ * @return the parsed header
+ */
+ public Map<String, String> parseWWWAuthenticateHeader(String WWWAuthenticate) {
+
+ Matcher m = pattern.matcher(WWWAuthenticate);
+ Map<String, String> challenge = new HashMap<>();
+
+ while (m.find()) {
+ String key = m.group(1);
+ String qtedValue = m.group(3);
+ String value = m.group(4);
+
+ challenge.put(key, qtedValue != null ? qtedValue : value);
+
+ }
+
+ return challenge;
+
+ }
+
+}
diff --git a/src/java/nginx/unit/websocket/AuthenticatorFactory.java b/src/java/nginx/unit/websocket/AuthenticatorFactory.java
new file mode 100644
index 00000000..7d46d7f9
--- /dev/null
+++ b/src/java/nginx/unit/websocket/AuthenticatorFactory.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.Iterator;
+import java.util.ServiceLoader;
+
+/**
+ * Utility method to return the appropriate authenticator according to
+ * the scheme that the server uses.
+ */
+public class AuthenticatorFactory {
+
+ /**
+ * Return a new authenticator instance.
+ * @param authScheme The scheme used
+ * @return the authenticator
+ */
+ public static Authenticator getAuthenticator(String authScheme) {
+
+ Authenticator auth = null;
+ switch (authScheme.toLowerCase()) {
+
+ case BasicAuthenticator.schemeName:
+ auth = new BasicAuthenticator();
+ break;
+
+ case DigestAuthenticator.schemeName:
+ auth = new DigestAuthenticator();
+ break;
+
+ default:
+ auth = loadAuthenticators(authScheme);
+ break;
+ }
+
+ return auth;
+
+ }
+
+ private static Authenticator loadAuthenticators(String authScheme) {
+ ServiceLoader<Authenticator> serviceLoader = ServiceLoader.load(Authenticator.class);
+ Iterator<Authenticator> auths = serviceLoader.iterator();
+
+ while (auths.hasNext()) {
+ Authenticator auth = auths.next();
+ if (auth.getSchemeName().equalsIgnoreCase(authScheme))
+ return auth;
+ }
+
+ return null;
+ }
+
+}
diff --git a/src/java/nginx/unit/websocket/BackgroundProcess.java b/src/java/nginx/unit/websocket/BackgroundProcess.java
new file mode 100644
index 00000000..0d2e1288
--- /dev/null
+++ b/src/java/nginx/unit/websocket/BackgroundProcess.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+public interface BackgroundProcess {
+
+ void backgroundProcess();
+
+ void setProcessPeriod(int period);
+
+ int getProcessPeriod();
+}
diff --git a/src/java/nginx/unit/websocket/BackgroundProcessManager.java b/src/java/nginx/unit/websocket/BackgroundProcessManager.java
new file mode 100644
index 00000000..d8b1b950
--- /dev/null
+++ b/src/java/nginx/unit/websocket/BackgroundProcessManager.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.HashSet;
+import java.util.Set;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.ExceptionUtils;
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Provides a background processing mechanism that triggers roughly once a
+ * second. The class maintains a thread that only runs when there is at least
+ * one instance of {@link BackgroundProcess} registered.
+ */
+public class BackgroundProcessManager {
+
+ private final Log log =
+ LogFactory.getLog(BackgroundProcessManager.class);
+ private static final StringManager sm =
+ StringManager.getManager(BackgroundProcessManager.class);
+ private static final BackgroundProcessManager instance;
+
+
+ static {
+ instance = new BackgroundProcessManager();
+ }
+
+
+ public static BackgroundProcessManager getInstance() {
+ return instance;
+ }
+
+ private final Set<BackgroundProcess> processes = new HashSet<>();
+ private final Object processesLock = new Object();
+ private WsBackgroundThread wsBackgroundThread = null;
+
+ private BackgroundProcessManager() {
+ // Hide default constructor
+ }
+
+
+ public void register(BackgroundProcess process) {
+ synchronized (processesLock) {
+ if (processes.size() == 0) {
+ wsBackgroundThread = new WsBackgroundThread(this);
+ wsBackgroundThread.setContextClassLoader(
+ this.getClass().getClassLoader());
+ wsBackgroundThread.setDaemon(true);
+ wsBackgroundThread.start();
+ }
+ processes.add(process);
+ }
+ }
+
+
+ public void unregister(BackgroundProcess process) {
+ synchronized (processesLock) {
+ processes.remove(process);
+ if (wsBackgroundThread != null && processes.size() == 0) {
+ wsBackgroundThread.halt();
+ wsBackgroundThread = null;
+ }
+ }
+ }
+
+
+ private void process() {
+ Set<BackgroundProcess> currentProcesses = new HashSet<>();
+ synchronized (processesLock) {
+ currentProcesses.addAll(processes);
+ }
+ for (BackgroundProcess process : currentProcesses) {
+ try {
+ process.backgroundProcess();
+ } catch (Throwable t) {
+ ExceptionUtils.handleThrowable(t);
+ log.error(sm.getString(
+ "backgroundProcessManager.processFailed"), t);
+ }
+ }
+ }
+
+
+ /*
+ * For unit testing.
+ */
+ int getProcessCount() {
+ synchronized (processesLock) {
+ return processes.size();
+ }
+ }
+
+
+ void shutdown() {
+ synchronized (processesLock) {
+ processes.clear();
+ if (wsBackgroundThread != null) {
+ wsBackgroundThread.halt();
+ wsBackgroundThread = null;
+ }
+ }
+ }
+
+
+ private static class WsBackgroundThread extends Thread {
+
+ private final BackgroundProcessManager manager;
+ private volatile boolean running = true;
+
+ public WsBackgroundThread(BackgroundProcessManager manager) {
+ setName("WebSocket background processing");
+ this.manager = manager;
+ }
+
+ @Override
+ public void run() {
+ while (running) {
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ // Ignore
+ }
+ manager.process();
+ }
+ }
+
+ public void halt() {
+ setName("WebSocket background processing - stopping");
+ running = false;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/BasicAuthenticator.java b/src/java/nginx/unit/websocket/BasicAuthenticator.java
new file mode 100644
index 00000000..1b1a6b83
--- /dev/null
+++ b/src/java/nginx/unit/websocket/BasicAuthenticator.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
+import java.util.Base64;
+import java.util.Map;
+
+/**
+ * Authenticator supporting the BASIC auth method.
+ */
+public class BasicAuthenticator extends Authenticator {
+
+ public static final String schemeName = "basic";
+ public static final String charsetparam = "charset";
+
+ @Override
+ public String getAuthorization(String requestUri, String WWWAuthenticate,
+ Map<String, Object> userProperties) throws AuthenticationException {
+
+ String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME);
+ String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD);
+
+ if (userName == null || password == null) {
+ throw new AuthenticationException(
+ "Failed to perform Basic authentication due to missing user/password");
+ }
+
+ Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate);
+
+ String userPass = userName + ":" + password;
+ Charset charset;
+
+ if (wwwAuthenticate.get(charsetparam) != null
+ && wwwAuthenticate.get(charsetparam).equalsIgnoreCase("UTF-8")) {
+ charset = StandardCharsets.UTF_8;
+ } else {
+ charset = StandardCharsets.ISO_8859_1;
+ }
+
+ String base64 = Base64.getEncoder().encodeToString(userPass.getBytes(charset));
+
+ return " Basic " + base64;
+ }
+
+ @Override
+ public String getSchemeName() {
+ return schemeName;
+ }
+
+}
diff --git a/src/java/nginx/unit/websocket/Constants.java b/src/java/nginx/unit/websocket/Constants.java
new file mode 100644
index 00000000..38b22fe0
--- /dev/null
+++ b/src/java/nginx/unit/websocket/Constants.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+import javax.websocket.Extension;
+
+/**
+ * Internal implementation constants.
+ */
+public class Constants {
+
+ // OP Codes
+ public static final byte OPCODE_CONTINUATION = 0x00;
+ public static final byte OPCODE_TEXT = 0x01;
+ public static final byte OPCODE_BINARY = 0x02;
+ public static final byte OPCODE_CLOSE = 0x08;
+ public static final byte OPCODE_PING = 0x09;
+ public static final byte OPCODE_PONG = 0x0A;
+
+ // Internal OP Codes
+ // RFC 6455 limits OP Codes to 4 bits so these should never clash
+ // Always set bit 4 so these will be treated as control codes
+ static final byte INTERNAL_OPCODE_FLUSH = 0x18;
+
+ // Buffers
+ static final int DEFAULT_BUFFER_SIZE = Integer.getInteger(
+ "nginx.unit.websocket.DEFAULT_BUFFER_SIZE", 8 * 1024)
+ .intValue();
+
+ // Client connection
+ /**
+ * Property name to set to configure the value that is passed to
+ * {@link javax.net.ssl.SSLEngine#setEnabledProtocols(String[])}. The value
+ * should be a comma separated string.
+ */
+ public static final String SSL_PROTOCOLS_PROPERTY =
+ "nginx.unit.websocket.SSL_PROTOCOLS";
+ public static final String SSL_TRUSTSTORE_PROPERTY =
+ "nginx.unit.websocket.SSL_TRUSTSTORE";
+ public static final String SSL_TRUSTSTORE_PWD_PROPERTY =
+ "nginx.unit.websocket.SSL_TRUSTSTORE_PWD";
+ public static final String SSL_TRUSTSTORE_PWD_DEFAULT = "changeit";
+ /**
+ * Property name to set to configure used SSLContext. The value should be an
+ * instance of SSLContext. If this property is present, the SSL_TRUSTSTORE*
+ * properties are ignored.
+ */
+ public static final String SSL_CONTEXT_PROPERTY =
+ "nginx.unit.websocket.SSL_CONTEXT";
+ /**
+ * Property name to set to configure the timeout (in milliseconds) when
+ * establishing a WebSocket connection to server. The default is
+ * {@link #IO_TIMEOUT_MS_DEFAULT}.
+ */
+ public static final String IO_TIMEOUT_MS_PROPERTY =
+ "nginx.unit.websocket.IO_TIMEOUT_MS";
+ public static final long IO_TIMEOUT_MS_DEFAULT = 5000;
+
+ // RFC 2068 recommended a limit of 5
+ // Most browsers have a default limit of 20
+ public static final String MAX_REDIRECTIONS_PROPERTY =
+ "nginx.unit.websocket.MAX_REDIRECTIONS";
+ public static final int MAX_REDIRECTIONS_DEFAULT = 20;
+
+ // HTTP upgrade header names and values
+ public static final String HOST_HEADER_NAME = "Host";
+ public static final String UPGRADE_HEADER_NAME = "Upgrade";
+ public static final String UPGRADE_HEADER_VALUE = "websocket";
+ public static final String ORIGIN_HEADER_NAME = "Origin";
+ public static final String CONNECTION_HEADER_NAME = "Connection";
+ public static final String CONNECTION_HEADER_VALUE = "upgrade";
+ public static final String LOCATION_HEADER_NAME = "Location";
+ public static final String AUTHORIZATION_HEADER_NAME = "Authorization";
+ public static final String WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate";
+ public static final String WS_VERSION_HEADER_NAME = "Sec-WebSocket-Version";
+ public static final String WS_VERSION_HEADER_VALUE = "13";
+ public static final String WS_KEY_HEADER_NAME = "Sec-WebSocket-Key";
+ public static final String WS_PROTOCOL_HEADER_NAME = "Sec-WebSocket-Protocol";
+ public static final String WS_EXTENSIONS_HEADER_NAME = "Sec-WebSocket-Extensions";
+
+ /// HTTP redirection status codes
+ public static final int MULTIPLE_CHOICES = 300;
+ public static final int MOVED_PERMANENTLY = 301;
+ public static final int FOUND = 302;
+ public static final int SEE_OTHER = 303;
+ public static final int USE_PROXY = 305;
+ public static final int TEMPORARY_REDIRECT = 307;
+
+ // Configuration for Origin header in client
+ static final String DEFAULT_ORIGIN_HEADER_VALUE =
+ System.getProperty("nginx.unit.websocket.DEFAULT_ORIGIN_HEADER_VALUE");
+
+ // Configuration for blocking sends
+ public static final String BLOCKING_SEND_TIMEOUT_PROPERTY =
+ "nginx.unit.websocket.BLOCKING_SEND_TIMEOUT";
+ // Milliseconds so this is 20 seconds
+ public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000;
+
+ // Configuration for background processing checks intervals
+ static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger(
+ "nginx.unit.websocket.DEFAULT_PROCESS_PERIOD", 10)
+ .intValue();
+
+ public static final String WS_AUTHENTICATION_USER_NAME = "nginx.unit.websocket.WS_AUTHENTICATION_USER_NAME";
+ public static final String WS_AUTHENTICATION_PASSWORD = "nginx.unit.websocket.WS_AUTHENTICATION_PASSWORD";
+
+ /* Configuration for extensions
+ * Note: These options are primarily present to enable this implementation
+ * to pass compliance tests. They are expected to be removed once
+ * the WebSocket API includes a mechanism for adding custom extensions
+ * and disabling built-in extensions.
+ */
+ static final boolean DISABLE_BUILTIN_EXTENSIONS =
+ Boolean.getBoolean("nginx.unit.websocket.DISABLE_BUILTIN_EXTENSIONS");
+ static final boolean ALLOW_UNSUPPORTED_EXTENSIONS =
+ Boolean.getBoolean("nginx.unit.websocket.ALLOW_UNSUPPORTED_EXTENSIONS");
+
+ // Configuration for stream behavior
+ static final boolean STREAMS_DROP_EMPTY_MESSAGES =
+ Boolean.getBoolean("nginx.unit.websocket.STREAMS_DROP_EMPTY_MESSAGES");
+
+ public static final boolean STRICT_SPEC_COMPLIANCE =
+ Boolean.getBoolean("nginx.unit.websocket.STRICT_SPEC_COMPLIANCE");
+
+ public static final List<Extension> INSTALLED_EXTENSIONS;
+
+ static {
+ if (DISABLE_BUILTIN_EXTENSIONS) {
+ INSTALLED_EXTENSIONS = Collections.unmodifiableList(new ArrayList<Extension>());
+ } else {
+ List<Extension> installed = new ArrayList<>(1);
+ installed.add(new WsExtension("permessage-deflate"));
+ INSTALLED_EXTENSIONS = Collections.unmodifiableList(installed);
+ }
+ }
+
+ private Constants() {
+ // Hide default constructor
+ }
+}
diff --git a/src/java/nginx/unit/websocket/DecoderEntry.java b/src/java/nginx/unit/websocket/DecoderEntry.java
new file mode 100644
index 00000000..36112ef4
--- /dev/null
+++ b/src/java/nginx/unit/websocket/DecoderEntry.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import javax.websocket.Decoder;
+
+public class DecoderEntry {
+
+ private final Class<?> clazz;
+ private final Class<? extends Decoder> decoderClazz;
+
+ public DecoderEntry(Class<?> clazz,
+ Class<? extends Decoder> decoderClazz) {
+ this.clazz = clazz;
+ this.decoderClazz = decoderClazz;
+ }
+
+ public Class<?> getClazz() {
+ return clazz;
+ }
+
+ public Class<? extends Decoder> getDecoderClazz() {
+ return decoderClazz;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/DigestAuthenticator.java b/src/java/nginx/unit/websocket/DigestAuthenticator.java
new file mode 100644
index 00000000..9530c303
--- /dev/null
+++ b/src/java/nginx/unit/websocket/DigestAuthenticator.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.charset.StandardCharsets;
+import java.security.MessageDigest;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
+import java.util.Map;
+
+import org.apache.tomcat.util.security.MD5Encoder;
+
+/**
+ * Authenticator supporting the DIGEST auth method.
+ */
+public class DigestAuthenticator extends Authenticator {
+
+ public static final String schemeName = "digest";
+ private SecureRandom cnonceGenerator;
+ private int nonceCount = 0;
+ private long cNonce;
+
+ @Override
+ public String getAuthorization(String requestUri, String WWWAuthenticate,
+ Map<String, Object> userProperties) throws AuthenticationException {
+
+ String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME);
+ String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD);
+
+ if (userName == null || password == null) {
+ throw new AuthenticationException(
+ "Failed to perform Digest authentication due to missing user/password");
+ }
+
+ Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate);
+
+ String realm = wwwAuthenticate.get("realm");
+ String nonce = wwwAuthenticate.get("nonce");
+ String messageQop = wwwAuthenticate.get("qop");
+ String algorithm = wwwAuthenticate.get("algorithm") == null ? "MD5"
+ : wwwAuthenticate.get("algorithm");
+ String opaque = wwwAuthenticate.get("opaque");
+
+ StringBuilder challenge = new StringBuilder();
+
+ if (!messageQop.isEmpty()) {
+ if (cnonceGenerator == null) {
+ cnonceGenerator = new SecureRandom();
+ }
+
+ cNonce = cnonceGenerator.nextLong();
+ nonceCount++;
+ }
+
+ challenge.append("Digest ");
+ challenge.append("username =\"" + userName + "\",");
+ challenge.append("realm=\"" + realm + "\",");
+ challenge.append("nonce=\"" + nonce + "\",");
+ challenge.append("uri=\"" + requestUri + "\",");
+
+ try {
+ challenge.append("response=\"" + calculateRequestDigest(requestUri, userName, password,
+ realm, nonce, messageQop, algorithm) + "\",");
+ }
+
+ catch (NoSuchAlgorithmException e) {
+ throw new AuthenticationException(
+ "Unable to generate request digest " + e.getMessage());
+ }
+
+ challenge.append("algorithm=" + algorithm + ",");
+ challenge.append("opaque=\"" + opaque + "\",");
+
+ if (!messageQop.isEmpty()) {
+ challenge.append("qop=\"" + messageQop + "\"");
+ challenge.append(",cnonce=\"" + cNonce + "\",");
+ challenge.append("nc=" + String.format("%08X", Integer.valueOf(nonceCount)));
+ }
+
+ return challenge.toString();
+
+ }
+
+ private String calculateRequestDigest(String requestUri, String userName, String password,
+ String realm, String nonce, String qop, String algorithm)
+ throws NoSuchAlgorithmException {
+
+ StringBuilder preDigest = new StringBuilder();
+ String A1;
+
+ if (algorithm.equalsIgnoreCase("MD5"))
+ A1 = userName + ":" + realm + ":" + password;
+
+ else
+ A1 = encodeMD5(userName + ":" + realm + ":" + password) + ":" + nonce + ":" + cNonce;
+
+ /*
+ * If the "qop" value is "auth-int", then A2 is: A2 = Method ":"
+ * digest-uri-value ":" H(entity-body) since we do not have an entity-body, A2 =
+ * Method ":" digest-uri-value for auth and auth_int
+ */
+ String A2 = "GET:" + requestUri;
+
+ preDigest.append(encodeMD5(A1));
+ preDigest.append(":");
+ preDigest.append(nonce);
+
+ if (qop.toLowerCase().contains("auth")) {
+ preDigest.append(":");
+ preDigest.append(String.format("%08X", Integer.valueOf(nonceCount)));
+ preDigest.append(":");
+ preDigest.append(String.valueOf(cNonce));
+ preDigest.append(":");
+ preDigest.append(qop);
+ }
+
+ preDigest.append(":");
+ preDigest.append(encodeMD5(A2));
+
+ return encodeMD5(preDigest.toString());
+
+ }
+
+ private String encodeMD5(String value) throws NoSuchAlgorithmException {
+ byte[] bytesOfMessage = value.getBytes(StandardCharsets.ISO_8859_1);
+ MessageDigest md = MessageDigest.getInstance("MD5");
+ byte[] thedigest = md.digest(bytesOfMessage);
+
+ return MD5Encoder.encode(thedigest);
+ }
+
+ @Override
+ public String getSchemeName() {
+ return schemeName;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/FutureToSendHandler.java b/src/java/nginx/unit/websocket/FutureToSendHandler.java
new file mode 100644
index 00000000..4a0809cb
--- /dev/null
+++ b/src/java/nginx/unit/websocket/FutureToSendHandler.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
+import org.apache.tomcat.util.res.StringManager;
+
+
+/**
+ * Converts a Future to a SendHandler.
+ */
+class FutureToSendHandler implements Future<Void>, SendHandler {
+
+ private static final StringManager sm = StringManager.getManager(FutureToSendHandler.class);
+
+ private final CountDownLatch latch = new CountDownLatch(1);
+ private final WsSession wsSession;
+ private volatile AtomicReference<SendResult> result = new AtomicReference<>(null);
+
+ public FutureToSendHandler(WsSession wsSession) {
+ this.wsSession = wsSession;
+ }
+
+
+ // --------------------------------------------------------- SendHandler
+
+ @Override
+ public void onResult(SendResult result) {
+ this.result.compareAndSet(null, result);
+ latch.countDown();
+ }
+
+
+ // -------------------------------------------------------------- Future
+
+ @Override
+ public boolean cancel(boolean mayInterruptIfRunning) {
+ // Cancelling the task is not supported
+ return false;
+ }
+
+ @Override
+ public boolean isCancelled() {
+ // Cancelling the task is not supported
+ return false;
+ }
+
+ @Override
+ public boolean isDone() {
+ return latch.getCount() == 0;
+ }
+
+ @Override
+ public Void get() throws InterruptedException,
+ ExecutionException {
+ try {
+ wsSession.registerFuture(this);
+ latch.await();
+ } finally {
+ wsSession.unregisterFuture(this);
+ }
+ if (result.get().getException() != null) {
+ throw new ExecutionException(result.get().getException());
+ }
+ return null;
+ }
+
+ @Override
+ public Void get(long timeout, TimeUnit unit)
+ throws InterruptedException, ExecutionException,
+ TimeoutException {
+ boolean retval = false;
+ try {
+ wsSession.registerFuture(this);
+ retval = latch.await(timeout, unit);
+ } finally {
+ wsSession.unregisterFuture(this);
+
+ }
+ if (retval == false) {
+ throw new TimeoutException(sm.getString("futureToSendHandler.timeout",
+ Long.valueOf(timeout), unit.toString().toLowerCase()));
+ }
+ if (result.get().getException() != null) {
+ throw new ExecutionException(result.get().getException());
+ }
+ return null;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/LocalStrings.properties b/src/java/nginx/unit/websocket/LocalStrings.properties
new file mode 100644
index 00000000..aeafe082
--- /dev/null
+++ b/src/java/nginx/unit/websocket/LocalStrings.properties
@@ -0,0 +1,147 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+asyncChannelGroup.createFail=Unable to create dedicated AsynchronousChannelGroup for WebSocket clients which is required to prevent memory leaks in complex class loader environments like JavaEE containers
+
+asyncChannelWrapperSecure.closeFail=Failed to close channel cleanly
+asyncChannelWrapperSecure.check.notOk=TLS handshake returned an unexpected status [{0}]
+asyncChannelWrapperSecure.check.unwrap=Bytes were written to the output during a read
+asyncChannelWrapperSecure.check.wrap=Bytes were consumed from the input during a write
+asyncChannelWrapperSecure.concurrentRead=Concurrent read operations are not permitted
+asyncChannelWrapperSecure.concurrentWrite=Concurrent write operations are not permitted
+asyncChannelWrapperSecure.eof=Unexpected end of stream
+asyncChannelWrapperSecure.notHandshaking=Unexpected state [NOT_HANDSHAKING] during TLS handshake
+asyncChannelWrapperSecure.readOverflow=Buffer overflow. [{0}] bytes to write into a [{1}] byte buffer that already contained [{2}] bytes.
+asyncChannelWrapperSecure.statusUnwrap=Unexpected Status of SSLEngineResult after an unwrap() operation
+asyncChannelWrapperSecure.statusWrap=Unexpected Status of SSLEngineResult after a wrap() operation
+asyncChannelWrapperSecure.tooBig=The result [{0}] is too big to be expressed as an Integer
+asyncChannelWrapperSecure.wrongStateRead=Flag that indicates a read is in progress was found to be false (it should have been true) when trying to complete a read operation
+asyncChannelWrapperSecure.wrongStateWrite=Flag that indicates a write is in progress was found to be false (it should have been true) when trying to complete a write operation
+
+backgroundProcessManager.processFailed=A background process failed
+
+caseInsensitiveKeyMap.nullKey=Null keys are not permitted
+
+futureToSendHandler.timeout=Operation timed out after waiting [{0}] [{1}] to complete
+
+perMessageDeflate.deflateFailed=Failed to decompress a compressed WebSocket frame
+perMessageDeflate.duplicateParameter=Duplicate definition of the [{0}] extension parameter
+perMessageDeflate.invalidWindowSize=An invalid windows of [{1}] size was specified for [{0}]. Valid values are whole numbers from 8 to 15 inclusive.
+perMessageDeflate.unknownParameter=An unknown extension parameter [{0}] was defined
+
+transformerFactory.unsupportedExtension=The extension [{0}] is not supported
+
+util.notToken=An illegal extension parameter was specified with name [{0}] and value [{1}]
+util.invalidMessageHandler=The message handler provided does not have an onMessage(Object) method
+util.invalidType=Unable to coerce value [{0}] to type [{1}]. That type is not supported.
+util.unknownDecoderType=The Decoder type [{0}] is not recognized
+
+# Note the wsFrame.* messages are used as close reasons in WebSocket control
+# frames and therefore must be 123 bytes (not characters) or less in length.
+# Messages are encoded using UTF-8 where a single character may be encoded in
+# as many as 4 bytes.
+wsFrame.alreadyResumed=Message receiving has already been resumed.
+wsFrame.alreadySuspended=Message receiving has already been suspended.
+wsFrame.bufferTooSmall=No async message support and buffer too small. Buffer size: [{0}], Message size: [{1}]
+wsFrame.byteToLongFail=Too many bytes ([{0}]) were provided to be converted into a long
+wsFrame.closed=New frame received after a close control frame
+wsFrame.controlFragmented=A fragmented control frame was received but control frames may not be fragmented
+wsFrame.controlPayloadTooBig=A control frame was sent with a payload of size [{0}] which is larger than the maximum permitted of 125 bytes
+wsFrame.controlNoFin=A control frame was sent that did not have the fin bit set. Control frames are not permitted to use continuation frames.
+wsFrame.illegalReadState=Unexpected read state [{0}]
+wsFrame.invalidOpCode= A WebSocket frame was sent with an unrecognised opCode of [{0}]
+wsFrame.invalidUtf8=A WebSocket text frame was received that could not be decoded to UTF-8 because it contained invalid byte sequences
+wsFrame.invalidUtf8Close=A WebSocket close frame was received with a close reason that contained invalid UTF-8 byte sequences
+wsFrame.ioeTriggeredClose=An unrecoverable IOException occurred so the connection was closed
+wsFrame.messageTooBig=The message was [{0}] bytes long but the MessageHandler has a limit of [{1}] bytes
+wsFrame.noContinuation=A new message was started when a continuation frame was expected
+wsFrame.notMasked=The client frame was not masked but all client frames must be masked
+wsFrame.oneByteCloseCode=The client sent a close frame with a single byte payload which is not valid
+wsFrame.partialHeaderComplete=WebSocket frame received. fin [{0}], rsv [{1}], OpCode [{2}], payload length [{3}]
+wsFrame.sessionClosed=The client data cannot be processed because the session has already been closed
+wsFrame.suspendRequested=Suspend of the message receiving has already been requested.
+wsFrame.textMessageTooBig=The decoded text message was too big for the output buffer and the endpoint does not support partial messages
+wsFrame.wrongRsv=The client frame set the reserved bits to [{0}] for a message with opCode [{1}] which was not supported by this endpoint
+
+wsFrameClient.ioe=Failure while reading data sent by server
+
+wsHandshakeRequest.invalidUri=The string [{0}] cannot be used to construct a valid URI
+wsHandshakeRequest.unknownScheme=The scheme [{0}] in the request is not recognised
+
+wsRemoteEndpoint.acquireTimeout=The current message was not fully sent within the specified timeout
+wsRemoteEndpoint.closed=Message will not be sent because the WebSocket session has been closed
+wsRemoteEndpoint.closedDuringMessage=The remainder of the message will not be sent because the WebSocket session has been closed
+wsRemoteEndpoint.closedOutputStream=This method may not be called as the OutputStream has been closed
+wsRemoteEndpoint.closedWriter=This method may not be called as the Writer has been closed
+wsRemoteEndpoint.changeType=When sending a fragmented message, all fragments must be of the same type
+wsRemoteEndpoint.concurrentMessageSend=Messages may not be sent concurrently even when using the asynchronous send messages. The client must wait for the previous message to complete before sending the next.
+wsRemoteEndpoint.flushOnCloseFailed=Batched messages still enabled after session has been closed. Unable to flush remaining batched message.
+wsRemoteEndpoint.invalidEncoder=The specified encoder of type [{0}] could not be instantiated
+wsRemoteEndpoint.noEncoder=No encoder specified for object of class [{0}]
+wsRemoteEndpoint.nullData=Invalid null data argument
+wsRemoteEndpoint.nullHandler=Invalid null handler argument
+wsRemoteEndpoint.sendInterrupt=The current thread was interrupted while waiting for a blocking send to complete
+wsRemoteEndpoint.tooMuchData=Ping or pong may not send more than 125 bytes
+wsRemoteEndpoint.wrongState=The remote endpoint was in state [{0}] which is an invalid state for called method
+
+# Note the following message is used as a close reason in a WebSocket control
+# frame and therefore must be 123 bytes (not characters) or less in length.
+# Messages are encoded using UTF-8 where a single character may be encoded in
+# as many as 4 bytes.
+wsSession.timeout=The WebSocket session [{0}] timeout expired
+
+wsSession.closed=The WebSocket session [{0}] has been closed and no method (apart from close()) may be called on a closed session
+wsSession.created=Created WebSocket session [{0}]
+wsSession.doClose=Closing WebSocket session [{1}]
+wsSession.duplicateHandlerBinary=A binary message handler has already been configured
+wsSession.duplicateHandlerPong=A pong message handler has already been configured
+wsSession.duplicateHandlerText=A text message handler has already been configured
+wsSession.invalidHandlerTypePong=A pong message handler must implement MessageHandler.Whole
+wsSession.flushFailOnClose=Failed to flush batched messages on session close
+wsSession.messageFailed=Unable to write the complete message as the WebSocket connection has been closed
+wsSession.sendCloseFail=Failed to send close message for session [{0}] to remote endpoint
+wsSession.removeHandlerFailed=Unable to remove the handler [{0}] as it was not registered with this session
+wsSession.unknownHandler=Unable to add the message handler [{0}] as it was for the unrecognised type [{1}]
+wsSession.unknownHandlerType=Unable to add the message handler [{0}] as it was wrapped as the unrecognised type [{1}]
+wsSession.instanceNew=Endpoint instance registration failed
+wsSession.instanceDestroy=Endpoint instance unregistration failed
+
+# Note the following message is used as a close reason in a WebSocket control
+# frame and therefore must be 123 bytes (not characters) or less in length.
+# Messages are encoded using UTF-8 where a single character may be encoded in
+# as many as 4 bytes.
+wsWebSocketContainer.shutdown=The web application is stopping
+
+wsWebSocketContainer.defaultConfiguratorFail=Failed to create the default configurator
+wsWebSocketContainer.endpointCreateFail=Failed to create a local endpoint of type [{0}]
+wsWebSocketContainer.maxBuffer=This implementation limits the maximum size of a buffer to Integer.MAX_VALUE
+wsWebSocketContainer.missingAnnotation=Cannot use POJO class [{0}] as it is not annotated with @ClientEndpoint
+wsWebSocketContainer.sessionCloseFail=Session with ID [{0}] did not close cleanly
+
+wsWebSocketContainer.asynchronousSocketChannelFail=Unable to open a connection to the server
+wsWebSocketContainer.httpRequestFailed=The HTTP request to initiate the WebSocket connection failed
+wsWebSocketContainer.invalidExtensionParameters=The server responded with extension parameters the client is unable to support
+wsWebSocketContainer.invalidHeader=Unable to parse HTTP header as no colon is present to delimit header name and header value in [{0}]. The header has been skipped.
+wsWebSocketContainer.invalidStatus=The HTTP response from the server [{0}] did not permit the HTTP upgrade to WebSocket
+wsWebSocketContainer.invalidSubProtocol=The WebSocket server returned multiple values for the Sec-WebSocket-Protocol header
+wsWebSocketContainer.pathNoHost=No host was specified in URI
+wsWebSocketContainer.pathWrongScheme=The scheme [{0}] is not supported. The supported schemes are ws and wss
+wsWebSocketContainer.proxyConnectFail=Failed to connect to the configured Proxy [{0}]. The HTTP response code was [{1}]
+wsWebSocketContainer.sslEngineFail=Unable to create SSLEngine to support SSL/TLS connections
+wsWebSocketContainer.missingLocationHeader=Failed to handle HTTP response code [{0}]. Missing Location header in response
+wsWebSocketContainer.redirectThreshold=Cyclic Location header [{0}] detected / reached max number of redirects [{1}] of max [{2}]
+wsWebSocketContainer.unsupportedAuthScheme=Failed to handle HTTP response code [{0}]. Unsupported Authentication scheme [{1}] returned in response
+wsWebSocketContainer.failedAuthentication=Failed to handle HTTP response code [{0}]. Authentication header was not accepted by server.
+wsWebSocketContainer.missingWWWAuthenticateHeader=Failed to handle HTTP response code [{0}]. Missing WWW-Authenticate header in response
diff --git a/src/java/nginx/unit/websocket/MessageHandlerResult.java b/src/java/nginx/unit/websocket/MessageHandlerResult.java
new file mode 100644
index 00000000..8d532d1e
--- /dev/null
+++ b/src/java/nginx/unit/websocket/MessageHandlerResult.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import javax.websocket.MessageHandler;
+
+public class MessageHandlerResult {
+
+ private final MessageHandler handler;
+ private final MessageHandlerResultType type;
+
+
+ public MessageHandlerResult(MessageHandler handler,
+ MessageHandlerResultType type) {
+ this.handler = handler;
+ this.type = type;
+ }
+
+
+ public MessageHandler getHandler() {
+ return handler;
+ }
+
+
+ public MessageHandlerResultType getType() {
+ return type;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/MessageHandlerResultType.java b/src/java/nginx/unit/websocket/MessageHandlerResultType.java
new file mode 100644
index 00000000..1961bb4f
--- /dev/null
+++ b/src/java/nginx/unit/websocket/MessageHandlerResultType.java
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+public enum MessageHandlerResultType {
+ BINARY,
+ TEXT,
+ PONG
+}
diff --git a/src/java/nginx/unit/websocket/MessagePart.java b/src/java/nginx/unit/websocket/MessagePart.java
new file mode 100644
index 00000000..b52c26f1
--- /dev/null
+++ b/src/java/nginx/unit/websocket/MessagePart.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.ByteBuffer;
+
+import javax.websocket.SendHandler;
+
+class MessagePart {
+ private final boolean fin;
+ private final int rsv;
+ private final byte opCode;
+ private final ByteBuffer payload;
+ private final SendHandler intermediateHandler;
+ private volatile SendHandler endHandler;
+ private final long blockingWriteTimeoutExpiry;
+
+ public MessagePart( boolean fin, int rsv, byte opCode, ByteBuffer payload,
+ SendHandler intermediateHandler, SendHandler endHandler,
+ long blockingWriteTimeoutExpiry) {
+ this.fin = fin;
+ this.rsv = rsv;
+ this.opCode = opCode;
+ this.payload = payload;
+ this.intermediateHandler = intermediateHandler;
+ this.endHandler = endHandler;
+ this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry;
+ }
+
+
+ public boolean isFin() {
+ return fin;
+ }
+
+
+ public int getRsv() {
+ return rsv;
+ }
+
+
+ public byte getOpCode() {
+ return opCode;
+ }
+
+
+ public ByteBuffer getPayload() {
+ return payload;
+ }
+
+
+ public SendHandler getIntermediateHandler() {
+ return intermediateHandler;
+ }
+
+
+ public SendHandler getEndHandler() {
+ return endHandler;
+ }
+
+ public void setEndHandler(SendHandler endHandler) {
+ this.endHandler = endHandler;
+ }
+
+ public long getBlockingWriteTimeoutExpiry() {
+ return blockingWriteTimeoutExpiry;
+ }
+}
+
+
diff --git a/src/java/nginx/unit/websocket/PerMessageDeflate.java b/src/java/nginx/unit/websocket/PerMessageDeflate.java
new file mode 100644
index 00000000..88e0a0bc
--- /dev/null
+++ b/src/java/nginx/unit/websocket/PerMessageDeflate.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.zip.DataFormatException;
+import java.util.zip.Deflater;
+import java.util.zip.Inflater;
+
+import javax.websocket.Extension;
+import javax.websocket.Extension.Parameter;
+import javax.websocket.SendHandler;
+
+import org.apache.tomcat.util.res.StringManager;
+
+public class PerMessageDeflate implements Transformation {
+
+ private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);
+
+ private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
+ private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
+ private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
+ private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
+
+ private static final int RSV_BITMASK = 0b100;
+ private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1};
+
+ public static final String NAME = "permessage-deflate";
+
+ private final boolean serverContextTakeover;
+ private final int serverMaxWindowBits;
+ private final boolean clientContextTakeover;
+ private final int clientMaxWindowBits;
+ private final boolean isServer;
+ private final Inflater inflater = new Inflater(true);
+ private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
+ private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];
+
+ private volatile Transformation next;
+ private volatile boolean skipDecompression = false;
+ private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private volatile boolean firstCompressedFrameWritten = false;
+ // Flag to track if a message is completely empty
+ private volatile boolean emptyMessage = true;
+
+ static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
+ // Accept the first preference that the endpoint is able to support
+ for (List<Parameter> preference : preferences) {
+ boolean ok = true;
+ boolean serverContextTakeover = true;
+ int serverMaxWindowBits = -1;
+ boolean clientContextTakeover = true;
+ int clientMaxWindowBits = -1;
+
+ for (Parameter param : preference) {
+ if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
+ if (serverContextTakeover) {
+ serverContextTakeover = false;
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ SERVER_NO_CONTEXT_TAKEOVER ));
+ }
+ } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
+ if (clientContextTakeover) {
+ clientContextTakeover = false;
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ CLIENT_NO_CONTEXT_TAKEOVER ));
+ }
+ } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) {
+ if (serverMaxWindowBits == -1) {
+ serverMaxWindowBits = Integer.parseInt(param.getValue());
+ if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) {
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.invalidWindowSize",
+ SERVER_MAX_WINDOW_BITS,
+ Integer.valueOf(serverMaxWindowBits)));
+ }
+ // Java SE API (as of Java 8) does not expose the API to
+ // control the Window size. It is effectively hard-coded
+ // to 15
+ if (isServer && serverMaxWindowBits != 15) {
+ ok = false;
+ break;
+ // Note server window size is not an issue for the
+ // client since the client will assume 15 and if the
+ // server uses a smaller window everything will
+ // still work
+ }
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ SERVER_MAX_WINDOW_BITS ));
+ }
+ } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) {
+ if (clientMaxWindowBits == -1) {
+ if (param.getValue() == null) {
+ // Hint to server that the client supports this
+ // option. Java SE API (as of Java 8) does not
+ // expose the API to control the Window size. It is
+ // effectively hard-coded to 15
+ clientMaxWindowBits = 15;
+ } else {
+ clientMaxWindowBits = Integer.parseInt(param.getValue());
+ if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) {
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.invalidWindowSize",
+ CLIENT_MAX_WINDOW_BITS,
+ Integer.valueOf(clientMaxWindowBits)));
+ }
+ }
+ // Java SE API (as of Java 8) does not expose the API to
+ // control the Window size. It is effectively hard-coded
+ // to 15
+ if (!isServer && clientMaxWindowBits != 15) {
+ ok = false;
+ break;
+ // Note client window size is not an issue for the
+ // server since the server will assume 15 and if the
+ // client uses a smaller window everything will
+ // still work
+ }
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ CLIENT_MAX_WINDOW_BITS ));
+ }
+ } else {
+ // Unknown parameter
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.unknownParameter", param.getName()));
+ }
+ }
+ if (ok) {
+ return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits,
+ clientContextTakeover, clientMaxWindowBits, isServer);
+ }
+ }
+ // Failed to negotiate agreeable terms
+ return null;
+ }
+
+
+ private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits,
+ boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) {
+ this.serverContextTakeover = serverContextTakeover;
+ this.serverMaxWindowBits = serverMaxWindowBits;
+ this.clientContextTakeover = clientContextTakeover;
+ this.clientMaxWindowBits = clientMaxWindowBits;
+ this.isServer = isServer;
+ }
+
+
+ @Override
+ public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
+ throws IOException {
+ // Control frames are never compressed and may appear in the middle of
+ // a WebSocket method. Pass them straight through.
+ if (Util.isControl(opCode)) {
+ return next.getMoreData(opCode, fin, rsv, dest);
+ }
+
+ if (!Util.isContinuation(opCode)) {
+ // First frame in new message
+ skipDecompression = (rsv & RSV_BITMASK) == 0;
+ }
+
+ // Pass uncompressed frames straight through.
+ if (skipDecompression) {
+ return next.getMoreData(opCode, fin, rsv, dest);
+ }
+
+ int written;
+ boolean usedEomBytes = false;
+
+ while (dest.remaining() > 0) {
+ // Space available in destination. Try and fill it.
+ try {
+ written = inflater.inflate(
+ dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
+ } catch (DataFormatException e) {
+ throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
+ }
+ dest.position(dest.position() + written);
+
+ if (inflater.needsInput() && !usedEomBytes ) {
+ if (dest.hasRemaining()) {
+ readBuffer.clear();
+ TransformationResult nextResult =
+ next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
+ inflater.setInput(
+ readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
+ if (TransformationResult.UNDERFLOW.equals(nextResult)) {
+ return nextResult;
+ } else if (TransformationResult.END_OF_FRAME.equals(nextResult) &&
+ readBuffer.position() == 0) {
+ if (fin) {
+ inflater.setInput(EOM_BYTES);
+ usedEomBytes = true;
+ } else {
+ return TransformationResult.END_OF_FRAME;
+ }
+ }
+ }
+ } else if (written == 0) {
+ if (fin && (isServer && !clientContextTakeover ||
+ !isServer && !serverContextTakeover)) {
+ inflater.reset();
+ }
+ return TransformationResult.END_OF_FRAME;
+ }
+ }
+
+ return TransformationResult.OVERFLOW;
+ }
+
+
+ @Override
+ public boolean validateRsv(int rsv, byte opCode) {
+ if (Util.isControl(opCode)) {
+ if ((rsv & RSV_BITMASK) != 0) {
+ return false;
+ } else {
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsv(rsv, opCode);
+ }
+ }
+ } else {
+ int rsvNext = rsv;
+ if ((rsv & RSV_BITMASK) != 0) {
+ rsvNext = rsv ^ RSV_BITMASK;
+ }
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsv(rsvNext, opCode);
+ }
+ }
+ }
+
+
+ @Override
+ public Extension getExtensionResponse() {
+ Extension result = new WsExtension(NAME);
+
+ List<Extension.Parameter> params = result.getParameters();
+
+ if (!serverContextTakeover) {
+ params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
+ }
+ if (serverMaxWindowBits != -1) {
+ params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS,
+ Integer.toString(serverMaxWindowBits)));
+ }
+ if (!clientContextTakeover) {
+ params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
+ }
+ if (clientMaxWindowBits != -1) {
+ params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS,
+ Integer.toString(clientMaxWindowBits)));
+ }
+
+ return result;
+ }
+
+
+ @Override
+ public void setNext(Transformation t) {
+ if (next == null) {
+ this.next = t;
+ } else {
+ next.setNext(t);
+ }
+ }
+
+
+ @Override
+ public boolean validateRsvBits(int i) {
+ if ((i & RSV_BITMASK) != 0) {
+ return false;
+ }
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsvBits(i | RSV_BITMASK);
+ }
+ }
+
+
+ @Override
+ public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) {
+ List<MessagePart> allCompressedParts = new ArrayList<>();
+
+ for (MessagePart uncompressedPart : uncompressedParts) {
+ byte opCode = uncompressedPart.getOpCode();
+ boolean emptyPart = uncompressedPart.getPayload().limit() == 0;
+ emptyMessage = emptyMessage && emptyPart;
+ if (Util.isControl(opCode)) {
+ // Control messages can appear in the middle of other messages
+ // and must not be compressed. Pass it straight through
+ allCompressedParts.add(uncompressedPart);
+ } else if (emptyMessage && uncompressedPart.isFin()) {
+ // Zero length messages can't be compressed so pass the
+ // final (empty) part straight through.
+ allCompressedParts.add(uncompressedPart);
+ } else {
+ List<MessagePart> compressedParts = new ArrayList<>();
+ ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
+ SendHandler uncompressedIntermediateHandler =
+ uncompressedPart.getIntermediateHandler();
+
+ deflater.setInput(uncompressedPayload.array(),
+ uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
+ uncompressedPayload.remaining());
+
+ int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
+ boolean deflateRequired = true;
+
+ while (deflateRequired) {
+ ByteBuffer compressedPayload = writeBuffer;
+
+ int written = deflater.deflate(compressedPayload.array(),
+ compressedPayload.arrayOffset() + compressedPayload.position(),
+ compressedPayload.remaining(), flush);
+ compressedPayload.position(compressedPayload.position() + written);
+
+ if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
+ // This message part has been fully processed by the
+ // deflater. Fire the send handler for this message part
+ // and move on to the next message part.
+ break;
+ }
+
+ // If this point is reached, a new compressed message part
+ // will be created...
+ MessagePart compressedPart;
+
+ // .. and a new writeBuffer will be required.
+ writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+
+ // Flip the compressed payload ready for writing
+ compressedPayload.flip();
+
+ boolean fin = uncompressedPart.isFin();
+ boolean full = compressedPayload.limit() == compressedPayload.capacity();
+ boolean needsInput = deflater.needsInput();
+ long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();
+
+ if (fin && !full && needsInput) {
+ // End of compressed message. Drop EOM bytes and output.
+ compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
+ compressedPart = new MessagePart(true, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ startNewMessage();
+ } else if (full && !needsInput) {
+ // Write buffer full and input message not fully read.
+ // Output and start new compressed part.
+ compressedPart = new MessagePart(false, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ } else if (!fin && full && needsInput) {
+ // Write buffer full and input message not fully read.
+ // Output and get more data.
+ compressedPart = new MessagePart(false, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ } else if (fin && full && needsInput) {
+ // Write buffer full. Input fully read. Deflater may be
+ // in one of four states:
+ // - output complete (just happened to align with end of
+ // buffer
+ // - in middle of EOM bytes
+ // - about to write EOM bytes
+ // - more data to write
+ int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
+ if (eomBufferWritten < EOM_BUFFER.length) {
+ // EOM has just been completed
+ compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
+ compressedPart = new MessagePart(true,
+ getRsv(uncompressedPart), opCode, compressedPayload,
+ uncompressedIntermediateHandler, uncompressedIntermediateHandler,
+ blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ startNewMessage();
+ } else {
+ // More data to write
+ // Copy bytes to new write buffer
+ writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
+ compressedPart = new MessagePart(false,
+ getRsv(uncompressedPart), opCode, compressedPayload,
+ uncompressedIntermediateHandler, uncompressedIntermediateHandler,
+ blockingWriteTimeoutExpiry);
+ }
+ } else {
+ throw new IllegalStateException("Should never happen");
+ }
+
+ // Add the newly created compressed part to the set of parts
+ // to pass on to the next transformation.
+ compressedParts.add(compressedPart);
+ }
+
+ SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
+ int size = compressedParts.size();
+ if (size > 0) {
+ compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
+ }
+
+ allCompressedParts.addAll(compressedParts);
+ }
+ }
+
+ if (next == null) {
+ return allCompressedParts;
+ } else {
+ return next.sendMessagePart(allCompressedParts);
+ }
+ }
+
+
+ private void startNewMessage() {
+ firstCompressedFrameWritten = false;
+ emptyMessage = true;
+ if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) {
+ deflater.reset();
+ }
+ }
+
+
+ private int getRsv(MessagePart uncompressedMessagePart) {
+ int result = uncompressedMessagePart.getRsv();
+ if (!firstCompressedFrameWritten) {
+ result += RSV_BITMASK;
+ firstCompressedFrameWritten = true;
+ }
+ return result;
+ }
+
+
+ @Override
+ public void close() {
+ // There will always be a next transformation
+ next.close();
+ inflater.end();
+ deflater.end();
+ }
+}
diff --git a/src/java/nginx/unit/websocket/ReadBufferOverflowException.java b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java
new file mode 100644
index 00000000..9ce7ac27
--- /dev/null
+++ b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+
+public class ReadBufferOverflowException extends IOException {
+
+ private static final long serialVersionUID = 1L;
+
+ private final int minBufferSize;
+
+ public ReadBufferOverflowException(int minBufferSize) {
+ this.minBufferSize = minBufferSize;
+ }
+
+ public int getMinBufferSize() {
+ return minBufferSize;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/Transformation.java b/src/java/nginx/unit/websocket/Transformation.java
new file mode 100644
index 00000000..45474c7d
--- /dev/null
+++ b/src/java/nginx/unit/websocket/Transformation.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+import javax.websocket.Extension;
+
+/**
+ * The internal representation of the transformation that a WebSocket extension
+ * performs on a message.
+ */
+public interface Transformation {
+
+ /**
+ * Sets the next transformation in the pipeline.
+ * @param t The next transformation
+ */
+ void setNext(Transformation t);
+
+ /**
+ * Validate that the RSV bit(s) required by this transformation are not
+ * being used by another extension. The implementation is expected to set
+ * any bits it requires before passing the set of in-use bits to the next
+ * transformation.
+ *
+ * @param i The RSV bits marked as in use so far as an int in the
+ * range zero to seven with RSV1 as the MSB and RSV3 as the
+ * LSB
+ *
+ * @return <code>true</code> if the combination of RSV bits used by the
+ * transformations in the pipeline do not conflict otherwise
+ * <code>false</code>
+ */
+ boolean validateRsvBits(int i);
+
+ /**
+ * Obtain the extension that describes the information to be returned to the
+ * client.
+ *
+ * @return The extension information that describes the parameters that have
+ * been agreed for this transformation
+ */
+ Extension getExtensionResponse();
+
+ /**
+ * Obtain more input data.
+ *
+ * @param opCode The opcode for the frame currently being processed
+ * @param fin Is this the final frame in this WebSocket message?
+ * @param rsv The reserved bits for the frame currently being
+ * processed
+ * @param dest The buffer in which the data is to be written
+ *
+ * @return The result of trying to read more data from the transform
+ *
+ * @throws IOException If an I/O error occurs while reading data from the
+ * transform
+ */
+ TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException;
+
+ /**
+ * Validates the RSV and opcode combination (assumed to have been extracted
+ * from a WebSocket Frame) for this extension. The implementation is
+ * expected to unset any RSV bits it has validated before passing the
+ * remaining RSV bits to the next transformation in the pipeline.
+ *
+ * @param rsv The RSV bits received as an int in the range zero to
+ * seven with RSV1 as the MSB and RSV3 as the LSB
+ * @param opCode The opCode received
+ *
+ * @return <code>true</code> if the RSV is valid otherwise
+ * <code>false</code>
+ */
+ boolean validateRsv(int rsv, byte opCode);
+
+ /**
+ * Takes the provided list of messages, transforms them, passes the
+ * transformed list on to the next transformation (if any) and then returns
+ * the resulting list of message parts after all of the transformations have
+ * been applied.
+ *
+ * @param messageParts The list of messages to be transformed
+ *
+ * @return The list of messages after this any any subsequent
+ * transformations have been applied. The size of the returned list
+ * may be bigger or smaller than the size of the input list
+ */
+ List<MessagePart> sendMessagePart(List<MessagePart> messageParts);
+
+ /**
+ * Clean-up any resources that were used by the transformation.
+ */
+ void close();
+}
diff --git a/src/java/nginx/unit/websocket/TransformationFactory.java b/src/java/nginx/unit/websocket/TransformationFactory.java
new file mode 100644
index 00000000..fac04555
--- /dev/null
+++ b/src/java/nginx/unit/websocket/TransformationFactory.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.List;
+
+import javax.websocket.Extension;
+
+import org.apache.tomcat.util.res.StringManager;
+
+public class TransformationFactory {
+
+ private static final StringManager sm = StringManager.getManager(TransformationFactory.class);
+
+ private static final TransformationFactory factory = new TransformationFactory();
+
+ private TransformationFactory() {
+ // Hide default constructor
+ }
+
+ public static TransformationFactory getInstance() {
+ return factory;
+ }
+
+ public Transformation create(String name, List<List<Extension.Parameter>> preferences,
+ boolean isServer) {
+ if (PerMessageDeflate.NAME.equals(name)) {
+ return PerMessageDeflate.negotiate(preferences, isServer);
+ }
+ if (Constants.ALLOW_UNSUPPORTED_EXTENSIONS) {
+ return null;
+ } else {
+ throw new IllegalArgumentException(
+ sm.getString("transformerFactory.unsupportedExtension", name));
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/TransformationResult.java b/src/java/nginx/unit/websocket/TransformationResult.java
new file mode 100644
index 00000000..0de35e55
--- /dev/null
+++ b/src/java/nginx/unit/websocket/TransformationResult.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+public enum TransformationResult {
+ /**
+ * The end of the available data was reached before the WebSocket frame was
+ * completely read.
+ */
+ UNDERFLOW,
+
+ /**
+ * The provided destination buffer was filled before all of the available
+ * data from the WebSocket frame could be processed.
+ */
+ OVERFLOW,
+
+ /**
+ * The end of the WebSocket frame was reached and all the data from that
+ * frame processed into the provided destination buffer.
+ */
+ END_OF_FRAME
+}
diff --git a/src/java/nginx/unit/websocket/Util.java b/src/java/nginx/unit/websocket/Util.java
new file mode 100644
index 00000000..6acf3ade
--- /dev/null
+++ b/src/java/nginx/unit/websocket/Util.java
@@ -0,0 +1,666 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.InputStream;
+import java.io.Reader;
+import java.lang.reflect.GenericArrayType;
+import java.lang.reflect.Method;
+import java.lang.reflect.ParameterizedType;
+import java.lang.reflect.Type;
+import java.lang.reflect.TypeVariable;
+import java.nio.ByteBuffer;
+import java.security.NoSuchAlgorithmException;
+import java.security.SecureRandom;
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Queue;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+import javax.websocket.CloseReason.CloseCode;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.Decoder;
+import javax.websocket.Decoder.Binary;
+import javax.websocket.Decoder.BinaryStream;
+import javax.websocket.Decoder.Text;
+import javax.websocket.Decoder.TextStream;
+import javax.websocket.DeploymentException;
+import javax.websocket.Encoder;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Extension;
+import javax.websocket.MessageHandler;
+import javax.websocket.PongMessage;
+import javax.websocket.Session;
+
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary;
+import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary;
+import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText;
+
+/**
+ * Utility class for internal use only within the
+ * {@link nginx.unit.websocket} package.
+ */
+public class Util {
+
+ private static final StringManager sm = StringManager.getManager(Util.class);
+ private static final Queue<SecureRandom> randoms =
+ new ConcurrentLinkedQueue<>();
+
+ private Util() {
+ // Hide default constructor
+ }
+
+
+ static boolean isControl(byte opCode) {
+ return (opCode & 0x08) != 0;
+ }
+
+
+ static boolean isText(byte opCode) {
+ return opCode == Constants.OPCODE_TEXT;
+ }
+
+
+ static boolean isContinuation(byte opCode) {
+ return opCode == Constants.OPCODE_CONTINUATION;
+ }
+
+
+ static CloseCode getCloseCode(int code) {
+ if (code > 2999 && code < 5000) {
+ return CloseCodes.getCloseCode(code);
+ }
+ switch (code) {
+ case 1000:
+ return CloseCodes.NORMAL_CLOSURE;
+ case 1001:
+ return CloseCodes.GOING_AWAY;
+ case 1002:
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1003:
+ return CloseCodes.CANNOT_ACCEPT;
+ case 1004:
+ // Should not be used in a close frame
+ // return CloseCodes.RESERVED;
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1005:
+ // Should not be used in a close frame
+ // return CloseCodes.NO_STATUS_CODE;
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1006:
+ // Should not be used in a close frame
+ // return CloseCodes.CLOSED_ABNORMALLY;
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1007:
+ return CloseCodes.NOT_CONSISTENT;
+ case 1008:
+ return CloseCodes.VIOLATED_POLICY;
+ case 1009:
+ return CloseCodes.TOO_BIG;
+ case 1010:
+ return CloseCodes.NO_EXTENSION;
+ case 1011:
+ return CloseCodes.UNEXPECTED_CONDITION;
+ case 1012:
+ // Not in RFC6455
+ // return CloseCodes.SERVICE_RESTART;
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1013:
+ // Not in RFC6455
+ // return CloseCodes.TRY_AGAIN_LATER;
+ return CloseCodes.PROTOCOL_ERROR;
+ case 1015:
+ // Should not be used in a close frame
+ // return CloseCodes.TLS_HANDSHAKE_FAILURE;
+ return CloseCodes.PROTOCOL_ERROR;
+ default:
+ return CloseCodes.PROTOCOL_ERROR;
+ }
+ }
+
+
+ static byte[] generateMask() {
+ // SecureRandom is not thread-safe so need to make sure only one thread
+ // uses it at a time. In theory, the pool could grow to the same size
+ // as the number of request processing threads. In reality it will be
+ // a lot smaller.
+
+ // Get a SecureRandom from the pool
+ SecureRandom sr = randoms.poll();
+
+ // If one isn't available, generate a new one
+ if (sr == null) {
+ try {
+ sr = SecureRandom.getInstance("SHA1PRNG");
+ } catch (NoSuchAlgorithmException e) {
+ // Fall back to platform default
+ sr = new SecureRandom();
+ }
+ }
+
+ // Generate the mask
+ byte[] result = new byte[4];
+ sr.nextBytes(result);
+
+ // Put the SecureRandom back in the poll
+ randoms.add(sr);
+
+ return result;
+ }
+
+
+ static Class<?> getMessageType(MessageHandler listener) {
+ return Util.getGenericType(MessageHandler.class,
+ listener.getClass()).getClazz();
+ }
+
+
+ private static Class<?> getDecoderType(Class<? extends Decoder> decoder) {
+ return Util.getGenericType(Decoder.class, decoder).getClazz();
+ }
+
+
+ static Class<?> getEncoderType(Class<? extends Encoder> encoder) {
+ return Util.getGenericType(Encoder.class, encoder).getClazz();
+ }
+
+
+ private static <T> TypeResult getGenericType(Class<T> type,
+ Class<? extends T> clazz) {
+
+ // Look to see if this class implements the interface of interest
+
+ // Get all the interfaces
+ Type[] interfaces = clazz.getGenericInterfaces();
+ for (Type iface : interfaces) {
+ // Only need to check interfaces that use generics
+ if (iface instanceof ParameterizedType) {
+ ParameterizedType pi = (ParameterizedType) iface;
+ // Look for the interface of interest
+ if (pi.getRawType() instanceof Class) {
+ if (type.isAssignableFrom((Class<?>) pi.getRawType())) {
+ return getTypeParameter(
+ clazz, pi.getActualTypeArguments()[0]);
+ }
+ }
+ }
+ }
+
+ // Interface not found on this class. Look at the superclass.
+ @SuppressWarnings("unchecked")
+ Class<? extends T> superClazz =
+ (Class<? extends T>) clazz.getSuperclass();
+ if (superClazz == null) {
+ // Finished looking up the class hierarchy without finding anything
+ return null;
+ }
+
+ TypeResult superClassTypeResult = getGenericType(type, superClazz);
+ int dimension = superClassTypeResult.getDimension();
+ if (superClassTypeResult.getIndex() == -1 && dimension == 0) {
+ // Superclass implements interface and defines explicit type for
+ // the interface of interest
+ return superClassTypeResult;
+ }
+
+ if (superClassTypeResult.getIndex() > -1) {
+ // Superclass implements interface and defines unknown type for
+ // the interface of interest
+ // Map that unknown type to the generic types defined in this class
+ ParameterizedType superClassType =
+ (ParameterizedType) clazz.getGenericSuperclass();
+ TypeResult result = getTypeParameter(clazz,
+ superClassType.getActualTypeArguments()[
+ superClassTypeResult.getIndex()]);
+ result.incrementDimension(superClassTypeResult.getDimension());
+ if (result.getClazz() != null && result.getDimension() > 0) {
+ superClassTypeResult = result;
+ } else {
+ return result;
+ }
+ }
+
+ if (superClassTypeResult.getDimension() > 0) {
+ StringBuilder className = new StringBuilder();
+ for (int i = 0; i < dimension; i++) {
+ className.append('[');
+ }
+ className.append('L');
+ className.append(superClassTypeResult.getClazz().getCanonicalName());
+ className.append(';');
+
+ Class<?> arrayClazz;
+ try {
+ arrayClazz = Class.forName(className.toString());
+ } catch (ClassNotFoundException e) {
+ throw new IllegalArgumentException(e);
+ }
+
+ return new TypeResult(arrayClazz, -1, 0);
+ }
+
+ // Error will be logged further up the call stack
+ return null;
+ }
+
+
+ /*
+ * For a generic parameter, return either the Class used or if the type
+ * is unknown, the index for the type in definition of the class
+ */
+ private static TypeResult getTypeParameter(Class<?> clazz, Type argType) {
+ if (argType instanceof Class<?>) {
+ return new TypeResult((Class<?>) argType, -1, 0);
+ } else if (argType instanceof ParameterizedType) {
+ return new TypeResult((Class<?>)((ParameterizedType) argType).getRawType(), -1, 0);
+ } else if (argType instanceof GenericArrayType) {
+ Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType();
+ TypeResult result = getTypeParameter(clazz, arrayElementType);
+ result.incrementDimension(1);
+ return result;
+ } else {
+ TypeVariable<?>[] tvs = clazz.getTypeParameters();
+ for (int i = 0; i < tvs.length; i++) {
+ if (tvs[i].equals(argType)) {
+ return new TypeResult(null, i, 0);
+ }
+ }
+ return null;
+ }
+ }
+
+
+ public static boolean isPrimitive(Class<?> clazz) {
+ if (clazz.isPrimitive()) {
+ return true;
+ } else if(clazz.equals(Boolean.class) ||
+ clazz.equals(Byte.class) ||
+ clazz.equals(Character.class) ||
+ clazz.equals(Double.class) ||
+ clazz.equals(Float.class) ||
+ clazz.equals(Integer.class) ||
+ clazz.equals(Long.class) ||
+ clazz.equals(Short.class)) {
+ return true;
+ }
+ return false;
+ }
+
+
+ public static Object coerceToType(Class<?> type, String value) {
+ if (type.equals(String.class)) {
+ return value;
+ } else if (type.equals(boolean.class) || type.equals(Boolean.class)) {
+ return Boolean.valueOf(value);
+ } else if (type.equals(byte.class) || type.equals(Byte.class)) {
+ return Byte.valueOf(value);
+ } else if (type.equals(char.class) || type.equals(Character.class)) {
+ return Character.valueOf(value.charAt(0));
+ } else if (type.equals(double.class) || type.equals(Double.class)) {
+ return Double.valueOf(value);
+ } else if (type.equals(float.class) || type.equals(Float.class)) {
+ return Float.valueOf(value);
+ } else if (type.equals(int.class) || type.equals(Integer.class)) {
+ return Integer.valueOf(value);
+ } else if (type.equals(long.class) || type.equals(Long.class)) {
+ return Long.valueOf(value);
+ } else if (type.equals(short.class) || type.equals(Short.class)) {
+ return Short.valueOf(value);
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "util.invalidType", value, type.getName()));
+ }
+ }
+
+
+ public static List<DecoderEntry> getDecoders(
+ List<Class<? extends Decoder>> decoderClazzes)
+ throws DeploymentException{
+
+ List<DecoderEntry> result = new ArrayList<>();
+ if (decoderClazzes != null) {
+ for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
+ // Need to instantiate decoder to ensure it is valid and that
+ // deployment can be failed if it is not
+ @SuppressWarnings("unused")
+ Decoder instance;
+ try {
+ instance = decoderClazz.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(
+ sm.getString("pojoMethodMapping.invalidDecoder",
+ decoderClazz.getName()), e);
+ }
+ DecoderEntry entry = new DecoderEntry(
+ Util.getDecoderType(decoderClazz), decoderClazz);
+ result.add(entry);
+ }
+ }
+
+ return result;
+ }
+
+
+ static Set<MessageHandlerResult> getMessageHandlers(Class<?> target,
+ MessageHandler listener, EndpointConfig endpointConfig,
+ Session session) {
+
+ // Will never be more than 2 types
+ Set<MessageHandlerResult> results = new HashSet<>(2);
+
+ // Simple cases - handlers already accepts one of the types expected by
+ // the frame handling code
+ if (String.class.isAssignableFrom(target)) {
+ MessageHandlerResult result =
+ new MessageHandlerResult(listener,
+ MessageHandlerResultType.TEXT);
+ results.add(result);
+ } else if (ByteBuffer.class.isAssignableFrom(target)) {
+ MessageHandlerResult result =
+ new MessageHandlerResult(listener,
+ MessageHandlerResultType.BINARY);
+ results.add(result);
+ } else if (PongMessage.class.isAssignableFrom(target)) {
+ MessageHandlerResult result =
+ new MessageHandlerResult(listener,
+ MessageHandlerResultType.PONG);
+ results.add(result);
+ // Handler needs wrapping and optional decoder to convert it to one of
+ // the types expected by the frame handling code
+ } else if (byte[].class.isAssignableFrom(target)) {
+ boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass());
+ MessageHandlerResult result = new MessageHandlerResult(
+ whole ? new PojoMessageHandlerWholeBinary(listener,
+ getOnMessageMethod(listener), session,
+ endpointConfig, matchDecoders(target, endpointConfig, true),
+ new Object[1], 0, true, -1, false, -1) :
+ new PojoMessageHandlerPartialBinary(listener,
+ getOnMessagePartialMethod(listener), session,
+ new Object[2], 0, true, 1, -1, -1),
+ MessageHandlerResultType.BINARY);
+ results.add(result);
+ } else if (InputStream.class.isAssignableFrom(target)) {
+ MessageHandlerResult result = new MessageHandlerResult(
+ new PojoMessageHandlerWholeBinary(listener,
+ getOnMessageMethod(listener), session,
+ endpointConfig, matchDecoders(target, endpointConfig, true),
+ new Object[1], 0, true, -1, true, -1),
+ MessageHandlerResultType.BINARY);
+ results.add(result);
+ } else if (Reader.class.isAssignableFrom(target)) {
+ MessageHandlerResult result = new MessageHandlerResult(
+ new PojoMessageHandlerWholeText(listener,
+ getOnMessageMethod(listener), session,
+ endpointConfig, matchDecoders(target, endpointConfig, false),
+ new Object[1], 0, true, -1, -1),
+ MessageHandlerResultType.TEXT);
+ results.add(result);
+ } else {
+ // Handler needs wrapping and requires decoder to convert it to one
+ // of the types expected by the frame handling code
+ DecoderMatch decoderMatch = matchDecoders(target, endpointConfig);
+ Method m = getOnMessageMethod(listener);
+ if (decoderMatch.getBinaryDecoders().size() > 0) {
+ MessageHandlerResult result = new MessageHandlerResult(
+ new PojoMessageHandlerWholeBinary(listener, m, session,
+ endpointConfig,
+ decoderMatch.getBinaryDecoders(), new Object[1],
+ 0, false, -1, false, -1),
+ MessageHandlerResultType.BINARY);
+ results.add(result);
+ }
+ if (decoderMatch.getTextDecoders().size() > 0) {
+ MessageHandlerResult result = new MessageHandlerResult(
+ new PojoMessageHandlerWholeText(listener, m, session,
+ endpointConfig,
+ decoderMatch.getTextDecoders(), new Object[1],
+ 0, false, -1, -1),
+ MessageHandlerResultType.TEXT);
+ results.add(result);
+ }
+ }
+
+ if (results.size() == 0) {
+ throw new IllegalArgumentException(
+ sm.getString("wsSession.unknownHandler", listener, target));
+ }
+
+ return results;
+ }
+
+ private static List<Class<? extends Decoder>> matchDecoders(Class<?> target,
+ EndpointConfig endpointConfig, boolean binary) {
+ DecoderMatch decoderMatch = matchDecoders(target, endpointConfig);
+ if (binary) {
+ if (decoderMatch.getBinaryDecoders().size() > 0) {
+ return decoderMatch.getBinaryDecoders();
+ }
+ } else if (decoderMatch.getTextDecoders().size() > 0) {
+ return decoderMatch.getTextDecoders();
+ }
+ return null;
+ }
+
+ private static DecoderMatch matchDecoders(Class<?> target,
+ EndpointConfig endpointConfig) {
+ DecoderMatch decoderMatch;
+ try {
+ List<Class<? extends Decoder>> decoders =
+ endpointConfig.getDecoders();
+ List<DecoderEntry> decoderEntries = getDecoders(decoders);
+ decoderMatch = new DecoderMatch(target, decoderEntries);
+ } catch (DeploymentException e) {
+ throw new IllegalArgumentException(e);
+ }
+ return decoderMatch;
+ }
+
+ public static void parseExtensionHeader(List<Extension> extensions,
+ String header) {
+ // The relevant ABNF for the Sec-WebSocket-Extensions is as follows:
+ // extension-list = 1#extension
+ // extension = extension-token *( ";" extension-param )
+ // extension-token = registered-token
+ // registered-token = token
+ // extension-param = token [ "=" (token | quoted-string) ]
+ // ; When using the quoted-string syntax variant, the value
+ // ; after quoted-string unescaping MUST conform to the
+ // ; 'token' ABNF.
+ //
+ // The limiting of parameter values to tokens or "quoted tokens" makes
+ // the parsing of the header significantly simpler and allows a number
+ // of short-cuts to be taken.
+
+ // Step one, split the header into individual extensions using ',' as a
+ // separator
+ String unparsedExtensions[] = header.split(",");
+ for (String unparsedExtension : unparsedExtensions) {
+ // Step two, split the extension into the registered name and
+ // parameter/value pairs using ';' as a separator
+ String unparsedParameters[] = unparsedExtension.split(";");
+ WsExtension extension = new WsExtension(unparsedParameters[0].trim());
+
+ for (int i = 1; i < unparsedParameters.length; i++) {
+ int equalsPos = unparsedParameters[i].indexOf('=');
+ String name;
+ String value;
+ if (equalsPos == -1) {
+ name = unparsedParameters[i].trim();
+ value = null;
+ } else {
+ name = unparsedParameters[i].substring(0, equalsPos).trim();
+ value = unparsedParameters[i].substring(equalsPos + 1).trim();
+ int len = value.length();
+ if (len > 1) {
+ if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') {
+ value = value.substring(1, value.length() - 1);
+ }
+ }
+ }
+ // Make sure value doesn't contain any of the delimiters since
+ // that would indicate something went wrong
+ if (containsDelims(name) || containsDelims(value)) {
+ throw new IllegalArgumentException(sm.getString(
+ "util.notToken", name, value));
+ }
+ if (value != null &&
+ (value.indexOf(',') > -1 || value.indexOf(';') > -1 ||
+ value.indexOf('\"') > -1 || value.indexOf('=') > -1)) {
+ throw new IllegalArgumentException(sm.getString("", value));
+ }
+ extension.addParameter(new WsExtensionParameter(name, value));
+ }
+ extensions.add(extension);
+ }
+ }
+
+
+ private static boolean containsDelims(String input) {
+ if (input == null || input.length() == 0) {
+ return false;
+ }
+ for (char c : input.toCharArray()) {
+ switch (c) {
+ case ',':
+ case ';':
+ case '\"':
+ case '=':
+ return true;
+ default:
+ // NO_OP
+ }
+
+ }
+ return false;
+ }
+
+ private static Method getOnMessageMethod(MessageHandler listener) {
+ try {
+ return listener.getClass().getMethod("onMessage", Object.class);
+ } catch (NoSuchMethodException | SecurityException e) {
+ throw new IllegalArgumentException(
+ sm.getString("util.invalidMessageHandler"), e);
+ }
+ }
+
+ private static Method getOnMessagePartialMethod(MessageHandler listener) {
+ try {
+ return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE);
+ } catch (NoSuchMethodException | SecurityException e) {
+ throw new IllegalArgumentException(
+ sm.getString("util.invalidMessageHandler"), e);
+ }
+ }
+
+
+ public static class DecoderMatch {
+
+ private final List<Class<? extends Decoder>> textDecoders =
+ new ArrayList<>();
+ private final List<Class<? extends Decoder>> binaryDecoders =
+ new ArrayList<>();
+ private final Class<?> target;
+
+ public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) {
+ this.target = target;
+ for (DecoderEntry decoderEntry : decoderEntries) {
+ if (decoderEntry.getClazz().isAssignableFrom(target)) {
+ if (Binary.class.isAssignableFrom(
+ decoderEntry.getDecoderClazz())) {
+ binaryDecoders.add(decoderEntry.getDecoderClazz());
+ // willDecode() method means this decoder may or may not
+ // decode a message so need to carry on checking for
+ // other matches
+ } else if (BinaryStream.class.isAssignableFrom(
+ decoderEntry.getDecoderClazz())) {
+ binaryDecoders.add(decoderEntry.getDecoderClazz());
+ // Stream decoders have to process the message so no
+ // more decoders can be matched
+ break;
+ } else if (Text.class.isAssignableFrom(
+ decoderEntry.getDecoderClazz())) {
+ textDecoders.add(decoderEntry.getDecoderClazz());
+ // willDecode() method means this decoder may or may not
+ // decode a message so need to carry on checking for
+ // other matches
+ } else if (TextStream.class.isAssignableFrom(
+ decoderEntry.getDecoderClazz())) {
+ textDecoders.add(decoderEntry.getDecoderClazz());
+ // Stream decoders have to process the message so no
+ // more decoders can be matched
+ break;
+ } else {
+ throw new IllegalArgumentException(
+ sm.getString("util.unknownDecoderType"));
+ }
+ }
+ }
+ }
+
+
+ public List<Class<? extends Decoder>> getTextDecoders() {
+ return textDecoders;
+ }
+
+
+ public List<Class<? extends Decoder>> getBinaryDecoders() {
+ return binaryDecoders;
+ }
+
+
+ public Class<?> getTarget() {
+ return target;
+ }
+
+
+ public boolean hasMatches() {
+ return (textDecoders.size() > 0) || (binaryDecoders.size() > 0);
+ }
+ }
+
+
+ private static class TypeResult {
+ private final Class<?> clazz;
+ private final int index;
+ private int dimension;
+
+ public TypeResult(Class<?> clazz, int index, int dimension) {
+ this.clazz= clazz;
+ this.index = index;
+ this.dimension = dimension;
+ }
+
+ public Class<?> getClazz() {
+ return clazz;
+ }
+
+ public int getIndex() {
+ return index;
+ }
+
+ public int getDimension() {
+ return dimension;
+ }
+
+ public void incrementDimension(int inc) {
+ dimension += inc;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WrappedMessageHandler.java b/src/java/nginx/unit/websocket/WrappedMessageHandler.java
new file mode 100644
index 00000000..2557a73e
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WrappedMessageHandler.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import javax.websocket.MessageHandler;
+
+public interface WrappedMessageHandler {
+ long getMaxMessageSize();
+
+ MessageHandler getWrappedHandler();
+}
diff --git a/src/java/nginx/unit/websocket/WsContainerProvider.java b/src/java/nginx/unit/websocket/WsContainerProvider.java
new file mode 100644
index 00000000..f8a404a1
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsContainerProvider.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import javax.websocket.ContainerProvider;
+import javax.websocket.WebSocketContainer;
+
+public class WsContainerProvider extends ContainerProvider {
+
+ @Override
+ protected WebSocketContainer getContainer() {
+ return new WsWebSocketContainer();
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsExtension.java b/src/java/nginx/unit/websocket/WsExtension.java
new file mode 100644
index 00000000..3846feb1
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsExtension.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.websocket.Extension;
+
+public class WsExtension implements Extension {
+
+ private final String name;
+ private final List<Parameter> parameters = new ArrayList<>();
+
+ WsExtension(String name) {
+ this.name = name;
+ }
+
+ void addParameter(Parameter parameter) {
+ parameters.add(parameter);
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ @Override
+ public List<Parameter> getParameters() {
+ return parameters;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsExtensionParameter.java b/src/java/nginx/unit/websocket/WsExtensionParameter.java
new file mode 100644
index 00000000..9b82f1c7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsExtensionParameter.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import javax.websocket.Extension.Parameter;
+
+public class WsExtensionParameter implements Parameter {
+
+ private final String name;
+ private final String value;
+
+ WsExtensionParameter(String name, String value) {
+ this.name = name;
+ this.value = value;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ @Override
+ public String getValue() {
+ return value;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsFrameBase.java b/src/java/nginx/unit/websocket/WsFrameBase.java
new file mode 100644
index 00000000..06d20bf4
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsFrameBase.java
@@ -0,0 +1,1010 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.CharsetDecoder;
+import java.nio.charset.CoderResult;
+import java.nio.charset.CodingErrorAction;
+import java.util.List;
+import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
+
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.Extension;
+import javax.websocket.MessageHandler;
+import javax.websocket.PongMessage;
+
+import org.apache.juli.logging.Log;
+import org.apache.tomcat.util.ExceptionUtils;
+import org.apache.tomcat.util.buf.Utf8Decoder;
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Takes the ServletInputStream, processes the WebSocket frames it contains and
+ * extracts the messages. WebSocket Pings received will be responded to
+ * automatically without any action required by the application.
+ */
+public abstract class WsFrameBase {
+
+ private static final StringManager sm = StringManager.getManager(WsFrameBase.class);
+
+ // Connection level attributes
+ protected final WsSession wsSession;
+ protected final ByteBuffer inputBuffer;
+ private final Transformation transformation;
+
+ // Attributes for control messages
+ // Control messages can appear in the middle of other messages so need
+ // separate attributes
+ private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125);
+ private final CharBuffer controlBufferText = CharBuffer.allocate(125);
+
+ // Attributes of the current message
+ private final CharsetDecoder utf8DecoderControl = new Utf8Decoder().
+ onMalformedInput(CodingErrorAction.REPORT).
+ onUnmappableCharacter(CodingErrorAction.REPORT);
+ private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder().
+ onMalformedInput(CodingErrorAction.REPORT).
+ onUnmappableCharacter(CodingErrorAction.REPORT);
+ private boolean continuationExpected = false;
+ private boolean textMessage = false;
+ private ByteBuffer messageBufferBinary;
+ private CharBuffer messageBufferText;
+ // Cache the message handler in force when the message starts so it is used
+ // consistently for the entire message
+ private MessageHandler binaryMsgHandler = null;
+ private MessageHandler textMsgHandler = null;
+
+ // Attributes of the current frame
+ private boolean fin = false;
+ private int rsv = 0;
+ private byte opCode = 0;
+ private final byte[] mask = new byte[4];
+ private int maskIndex = 0;
+ private long payloadLength = 0;
+ private volatile long payloadWritten = 0;
+
+ // Attributes tracking state
+ private volatile State state = State.NEW_FRAME;
+ private volatile boolean open = true;
+
+ private static final AtomicReferenceFieldUpdater<WsFrameBase, ReadState> READ_STATE_UPDATER =
+ AtomicReferenceFieldUpdater.newUpdater(WsFrameBase.class, ReadState.class, "readState");
+ private volatile ReadState readState = ReadState.WAITING;
+
+ public WsFrameBase(WsSession wsSession, Transformation transformation) {
+ inputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ inputBuffer.position(0).limit(0);
+ messageBufferBinary = ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize());
+ messageBufferText = CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize());
+ this.wsSession = wsSession;
+ Transformation finalTransformation;
+ if (isMasked()) {
+ finalTransformation = new UnmaskTransformation();
+ } else {
+ finalTransformation = new NoopTransformation();
+ }
+ if (transformation == null) {
+ this.transformation = finalTransformation;
+ } else {
+ transformation.setNext(finalTransformation);
+ this.transformation = transformation;
+ }
+ }
+
+
+ protected void processInputBuffer() throws IOException {
+ while (!isSuspended()) {
+ wsSession.updateLastActive();
+ if (state == State.NEW_FRAME) {
+ if (!processInitialHeader()) {
+ break;
+ }
+ // If a close frame has been received, no further data should
+ // have seen
+ if (!open) {
+ throw new IOException(sm.getString("wsFrame.closed"));
+ }
+ }
+ if (state == State.PARTIAL_HEADER) {
+ if (!processRemainingHeader()) {
+ break;
+ }
+ }
+ if (state == State.DATA) {
+ if (!processData()) {
+ break;
+ }
+ }
+ }
+ }
+
+
+ /**
+ * @return <code>true</code> if sufficient data was present to process all
+ * of the initial header
+ */
+ private boolean processInitialHeader() throws IOException {
+ // Need at least two bytes of data to do this
+ if (inputBuffer.remaining() < 2) {
+ return false;
+ }
+ int b = inputBuffer.get();
+ fin = (b & 0x80) != 0;
+ rsv = (b & 0x70) >>> 4;
+ opCode = (byte) (b & 0x0F);
+ if (!transformation.validateRsv(rsv, opCode)) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.wrongRsv", Integer.valueOf(rsv), Integer.valueOf(opCode))));
+ }
+
+ if (Util.isControl(opCode)) {
+ if (!fin) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.controlFragmented")));
+ }
+ if (opCode != Constants.OPCODE_PING &&
+ opCode != Constants.OPCODE_PONG &&
+ opCode != Constants.OPCODE_CLOSE) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
+ }
+ } else {
+ if (continuationExpected) {
+ if (!Util.isContinuation(opCode)) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.noContinuation")));
+ }
+ } else {
+ try {
+ if (opCode == Constants.OPCODE_BINARY) {
+ // New binary message
+ textMessage = false;
+ int size = wsSession.getMaxBinaryMessageBufferSize();
+ if (size != messageBufferBinary.capacity()) {
+ messageBufferBinary = ByteBuffer.allocate(size);
+ }
+ binaryMsgHandler = wsSession.getBinaryMessageHandler();
+ textMsgHandler = null;
+ } else if (opCode == Constants.OPCODE_TEXT) {
+ // New text message
+ textMessage = true;
+ int size = wsSession.getMaxTextMessageBufferSize();
+ if (size != messageBufferText.capacity()) {
+ messageBufferText = CharBuffer.allocate(size);
+ }
+ binaryMsgHandler = null;
+ textMsgHandler = wsSession.getTextMessageHandler();
+ } else {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
+ }
+ } catch (IllegalStateException ise) {
+ // Thrown if the session is already closed
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.sessionClosed")));
+ }
+ }
+ continuationExpected = !fin;
+ }
+ b = inputBuffer.get();
+ // Client data must be masked
+ if ((b & 0x80) == 0 && isMasked()) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.notMasked")));
+ }
+ payloadLength = b & 0x7F;
+ state = State.PARTIAL_HEADER;
+ if (getLog().isDebugEnabled()) {
+ getLog().debug(sm.getString("wsFrame.partialHeaderComplete", Boolean.toString(fin),
+ Integer.toString(rsv), Integer.toString(opCode), Long.toString(payloadLength)));
+ }
+ return true;
+ }
+
+
+ protected abstract boolean isMasked();
+ protected abstract Log getLog();
+
+
+ /**
+ * @return <code>true</code> if sufficient data was present to complete the
+ * processing of the header
+ */
+ private boolean processRemainingHeader() throws IOException {
+ // Ignore the 2 bytes already read. 4 for the mask
+ int headerLength;
+ if (isMasked()) {
+ headerLength = 4;
+ } else {
+ headerLength = 0;
+ }
+ // Add additional bytes depending on length
+ if (payloadLength == 126) {
+ headerLength += 2;
+ } else if (payloadLength == 127) {
+ headerLength += 8;
+ }
+ if (inputBuffer.remaining() < headerLength) {
+ return false;
+ }
+ // Calculate new payload length if necessary
+ if (payloadLength == 126) {
+ payloadLength = byteArrayToLong(inputBuffer.array(),
+ inputBuffer.arrayOffset() + inputBuffer.position(), 2);
+ inputBuffer.position(inputBuffer.position() + 2);
+ } else if (payloadLength == 127) {
+ payloadLength = byteArrayToLong(inputBuffer.array(),
+ inputBuffer.arrayOffset() + inputBuffer.position(), 8);
+ inputBuffer.position(inputBuffer.position() + 8);
+ }
+ if (Util.isControl(opCode)) {
+ if (payloadLength > 125) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.controlPayloadTooBig", Long.valueOf(payloadLength))));
+ }
+ if (!fin) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.controlNoFin")));
+ }
+ }
+ if (isMasked()) {
+ inputBuffer.get(mask, 0, 4);
+ }
+ state = State.DATA;
+ return true;
+ }
+
+
+ private boolean processData() throws IOException {
+ boolean result;
+ if (Util.isControl(opCode)) {
+ result = processDataControl();
+ } else if (textMessage) {
+ if (textMsgHandler == null) {
+ result = swallowInput();
+ } else {
+ result = processDataText();
+ }
+ } else {
+ if (binaryMsgHandler == null) {
+ result = swallowInput();
+ } else {
+ result = processDataBinary();
+ }
+ }
+ checkRoomPayload();
+ return result;
+ }
+
+
+ private boolean processDataControl() throws IOException {
+ TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, controlBufferBinary);
+ if (TransformationResult.UNDERFLOW.equals(tr)) {
+ return false;
+ }
+ // Control messages have fixed message size so
+ // TransformationResult.OVERFLOW is not possible here
+
+ controlBufferBinary.flip();
+ if (opCode == Constants.OPCODE_CLOSE) {
+ open = false;
+ String reason = null;
+ int code = CloseCodes.NORMAL_CLOSURE.getCode();
+ if (controlBufferBinary.remaining() == 1) {
+ controlBufferBinary.clear();
+ // Payload must be zero or 2+ bytes long
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.oneByteCloseCode")));
+ }
+ if (controlBufferBinary.remaining() > 1) {
+ code = controlBufferBinary.getShort();
+ if (controlBufferBinary.remaining() > 0) {
+ CoderResult cr = utf8DecoderControl.decode(controlBufferBinary,
+ controlBufferText, true);
+ if (cr.isError()) {
+ controlBufferBinary.clear();
+ controlBufferText.clear();
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.invalidUtf8Close")));
+ }
+ // There will be no overflow as the output buffer is big
+ // enough. There will be no underflow as all the data is
+ // passed to the decoder in a single call.
+ controlBufferText.flip();
+ reason = controlBufferText.toString();
+ }
+ }
+ wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason));
+ } else if (opCode == Constants.OPCODE_PING) {
+ if (wsSession.isOpen()) {
+ wsSession.getBasicRemote().sendPong(controlBufferBinary);
+ }
+ } else if (opCode == Constants.OPCODE_PONG) {
+ MessageHandler.Whole<PongMessage> mhPong = wsSession.getPongMessageHandler();
+ if (mhPong != null) {
+ try {
+ mhPong.onMessage(new WsPongMessage(controlBufferBinary));
+ } catch (Throwable t) {
+ handleThrowableOnSend(t);
+ } finally {
+ controlBufferBinary.clear();
+ }
+ }
+ } else {
+ // Should have caught this earlier but just in case...
+ controlBufferBinary.clear();
+ throw new WsIOException(new CloseReason(
+ CloseCodes.PROTOCOL_ERROR,
+ sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode))));
+ }
+ controlBufferBinary.clear();
+ newFrame();
+ return true;
+ }
+
+
+ @SuppressWarnings("unchecked")
+ protected void sendMessageText(boolean last) throws WsIOException {
+ if (textMsgHandler instanceof WrappedMessageHandler) {
+ long maxMessageSize = ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize();
+ if (maxMessageSize > -1 && messageBufferText.remaining() > maxMessageSize) {
+ throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.messageTooBig",
+ Long.valueOf(messageBufferText.remaining()),
+ Long.valueOf(maxMessageSize))));
+ }
+ }
+
+ try {
+ if (textMsgHandler instanceof MessageHandler.Partial<?>) {
+ ((MessageHandler.Partial<String>) textMsgHandler)
+ .onMessage(messageBufferText.toString(), last);
+ } else {
+ // Caller ensures last == true if this branch is used
+ ((MessageHandler.Whole<String>) textMsgHandler)
+ .onMessage(messageBufferText.toString());
+ }
+ } catch (Throwable t) {
+ handleThrowableOnSend(t);
+ } finally {
+ messageBufferText.clear();
+ }
+ }
+
+
+ private boolean processDataText() throws IOException {
+ // Copy the available data to the buffer
+ TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
+ while (!TransformationResult.END_OF_FRAME.equals(tr)) {
+ // Frame not complete - we ran out of something
+ // Convert bytes to UTF-8
+ messageBufferBinary.flip();
+ while (true) {
+ CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText,
+ false);
+ if (cr.isError()) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.NOT_CONSISTENT,
+ sm.getString("wsFrame.invalidUtf8")));
+ } else if (cr.isOverflow()) {
+ // Ran out of space in text buffer - flush it
+ if (usePartial()) {
+ messageBufferText.flip();
+ sendMessageText(false);
+ messageBufferText.clear();
+ } else {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.textMessageTooBig")));
+ }
+ } else if (cr.isUnderflow()) {
+ // Compact what we have to create as much space as possible
+ messageBufferBinary.compact();
+
+ // Need more input
+ // What did we run out of?
+ if (TransformationResult.OVERFLOW.equals(tr)) {
+ // Ran out of message buffer - exit inner loop and
+ // refill
+ break;
+ } else {
+ // TransformationResult.UNDERFLOW
+ // Ran out of input data - get some more
+ return false;
+ }
+ }
+ }
+ // Read more input data
+ tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
+ }
+
+ messageBufferBinary.flip();
+ boolean last = false;
+ // Frame is fully received
+ // Convert bytes to UTF-8
+ while (true) {
+ CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText,
+ last);
+ if (cr.isError()) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.NOT_CONSISTENT,
+ sm.getString("wsFrame.invalidUtf8")));
+ } else if (cr.isOverflow()) {
+ // Ran out of space in text buffer - flush it
+ if (usePartial()) {
+ messageBufferText.flip();
+ sendMessageText(false);
+ messageBufferText.clear();
+ } else {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.textMessageTooBig")));
+ }
+ } else if (cr.isUnderflow() && !last) {
+ // End of frame and possible message as well.
+
+ if (continuationExpected) {
+ // If partial messages are supported, send what we have
+ // managed to decode
+ if (usePartial()) {
+ messageBufferText.flip();
+ sendMessageText(false);
+ messageBufferText.clear();
+ }
+ messageBufferBinary.compact();
+ newFrame();
+ // Process next frame
+ return true;
+ } else {
+ // Make sure coder has flushed all output
+ last = true;
+ }
+ } else {
+ // End of message
+ messageBufferText.flip();
+ sendMessageText(true);
+
+ newMessage();
+ return true;
+ }
+ }
+ }
+
+
+ private boolean processDataBinary() throws IOException {
+ // Copy the available data to the buffer
+ TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
+ while (!TransformationResult.END_OF_FRAME.equals(tr)) {
+ // Frame not complete - what did we run out of?
+ if (TransformationResult.UNDERFLOW.equals(tr)) {
+ // Ran out of input data - get some more
+ return false;
+ }
+
+ // Ran out of message buffer - flush it
+ if (!usePartial()) {
+ CloseReason cr = new CloseReason(CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.bufferTooSmall",
+ Integer.valueOf(messageBufferBinary.capacity()),
+ Long.valueOf(payloadLength)));
+ throw new WsIOException(cr);
+ }
+ messageBufferBinary.flip();
+ ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit());
+ copy.put(messageBufferBinary);
+ copy.flip();
+ sendMessageBinary(copy, false);
+ messageBufferBinary.clear();
+ // Read more data
+ tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary);
+ }
+
+ // Frame is fully received
+ // Send the message if either:
+ // - partial messages are supported
+ // - the message is complete
+ if (usePartial() || !continuationExpected) {
+ messageBufferBinary.flip();
+ ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit());
+ copy.put(messageBufferBinary);
+ copy.flip();
+ sendMessageBinary(copy, !continuationExpected);
+ messageBufferBinary.clear();
+ }
+
+ if (continuationExpected) {
+ // More data for this message expected, start a new frame
+ newFrame();
+ } else {
+ // Message is complete, start a new message
+ newMessage();
+ }
+
+ return true;
+ }
+
+
+ private void handleThrowableOnSend(Throwable t) throws WsIOException {
+ ExceptionUtils.handleThrowable(t);
+ wsSession.getLocal().onError(wsSession, t);
+ CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY,
+ sm.getString("wsFrame.ioeTriggeredClose"));
+ throw new WsIOException(cr);
+ }
+
+
+ @SuppressWarnings("unchecked")
+ protected void sendMessageBinary(ByteBuffer msg, boolean last) throws WsIOException {
+ if (binaryMsgHandler instanceof WrappedMessageHandler) {
+ long maxMessageSize = ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize();
+ if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) {
+ throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.messageTooBig",
+ Long.valueOf(msg.remaining()),
+ Long.valueOf(maxMessageSize))));
+ }
+ }
+ try {
+ if (binaryMsgHandler instanceof MessageHandler.Partial<?>) {
+ ((MessageHandler.Partial<ByteBuffer>) binaryMsgHandler).onMessage(msg, last);
+ } else {
+ // Caller ensures last == true if this branch is used
+ ((MessageHandler.Whole<ByteBuffer>) binaryMsgHandler).onMessage(msg);
+ }
+ } catch (Throwable t) {
+ handleThrowableOnSend(t);
+ }
+ }
+
+
+ private void newMessage() {
+ messageBufferBinary.clear();
+ messageBufferText.clear();
+ utf8DecoderMessage.reset();
+ continuationExpected = false;
+ newFrame();
+ }
+
+
+ private void newFrame() {
+ if (inputBuffer.remaining() == 0) {
+ inputBuffer.position(0).limit(0);
+ }
+
+ maskIndex = 0;
+ payloadWritten = 0;
+ state = State.NEW_FRAME;
+
+ // These get reset in processInitialHeader()
+ // fin, rsv, opCode, payloadLength, mask
+
+ checkRoomHeaders();
+ }
+
+
+ private void checkRoomHeaders() {
+ // Is the start of the current frame too near the end of the input
+ // buffer?
+ if (inputBuffer.capacity() - inputBuffer.position() < 131) {
+ // Limit based on a control frame with a full payload
+ makeRoom();
+ }
+ }
+
+
+ private void checkRoomPayload() {
+ if (inputBuffer.capacity() - inputBuffer.position() - payloadLength + payloadWritten < 0) {
+ makeRoom();
+ }
+ }
+
+
+ private void makeRoom() {
+ inputBuffer.compact();
+ inputBuffer.flip();
+ }
+
+
+ private boolean usePartial() {
+ if (Util.isControl(opCode)) {
+ return false;
+ } else if (textMessage) {
+ return textMsgHandler instanceof MessageHandler.Partial;
+ } else {
+ // Must be binary
+ return binaryMsgHandler instanceof MessageHandler.Partial;
+ }
+ }
+
+
+ private boolean swallowInput() {
+ long toSkip = Math.min(payloadLength - payloadWritten, inputBuffer.remaining());
+ inputBuffer.position(inputBuffer.position() + (int) toSkip);
+ payloadWritten += toSkip;
+ if (payloadWritten == payloadLength) {
+ if (continuationExpected) {
+ newFrame();
+ } else {
+ newMessage();
+ }
+ return true;
+ } else {
+ return false;
+ }
+ }
+
+
+ protected static long byteArrayToLong(byte[] b, int start, int len) throws IOException {
+ if (len > 8) {
+ throw new IOException(sm.getString("wsFrame.byteToLongFail", Long.valueOf(len)));
+ }
+ int shift = 0;
+ long result = 0;
+ for (int i = start + len - 1; i >= start; i--) {
+ result = result + ((b[i] & 0xFF) << shift);
+ shift += 8;
+ }
+ return result;
+ }
+
+
+ protected boolean isOpen() {
+ return open;
+ }
+
+
+ protected Transformation getTransformation() {
+ return transformation;
+ }
+
+
+ private enum State {
+ NEW_FRAME, PARTIAL_HEADER, DATA
+ }
+
+
+ /**
+ * WAITING - not suspended
+ * Server case: waiting for a notification that data
+ * is ready to be read from the socket, the socket is
+ * registered to the poller
+ * Client case: data has been read from the socket and
+ * is waiting for data to be processed
+ * PROCESSING - not suspended
+ * Server case: reading from the socket and processing
+ * the data
+ * Client case: processing the data if such has
+ * already been read and more data will be read from
+ * the socket
+ * SUSPENDING_WAIT - suspended, a call to suspend() was made while in
+ * WAITING state. A call to resume() will do nothing
+ * and will transition to WAITING state
+ * SUSPENDING_PROCESS - suspended, a call to suspend() was made while in
+ * PROCESSING state. A call to resume() will do
+ * nothing and will transition to PROCESSING state
+ * SUSPENDED - suspended
+ * Server case: processing data finished
+ * (SUSPENDING_PROCESS) / a notification was received
+ * that data is ready to be read from the socket
+ * (SUSPENDING_WAIT), socket is not registered to the
+ * poller
+ * Client case: processing data finished
+ * (SUSPENDING_PROCESS) / data has been read from the
+ * socket and is available for processing
+ * (SUSPENDING_WAIT)
+ * A call to resume() will:
+ * Server case: register the socket to the poller
+ * Client case: resume data processing
+ * CLOSING - not suspended, a close will be send
+ *
+ * <pre>
+ * resume data to be resume
+ * no action processed no action
+ * |---------------| |---------------| |----------|
+ * | v | v v |
+ * | |----------WAITING --------PROCESSING----| |
+ * | | ^ processing | |
+ * | | | finished | |
+ * | | | | |
+ * | suspend | suspend |
+ * | | | | |
+ * | | resume | |
+ * | | register socket to poller (server) | |
+ * | | resume data processing (client) | |
+ * | | | | |
+ * | v | v |
+ * SUSPENDING_WAIT | SUSPENDING_PROCESS
+ * | | |
+ * | data available | processing finished |
+ * |------------- SUSPENDED ----------------------|
+ * </pre>
+ */
+ protected enum ReadState {
+ WAITING (false),
+ PROCESSING (false),
+ SUSPENDING_WAIT (true),
+ SUSPENDING_PROCESS(true),
+ SUSPENDED (true),
+ CLOSING (false);
+
+ private final boolean isSuspended;
+
+ ReadState(boolean isSuspended) {
+ this.isSuspended = isSuspended;
+ }
+
+ public boolean isSuspended() {
+ return isSuspended;
+ }
+ }
+
+ public void suspend() {
+ while (true) {
+ switch (readState) {
+ case WAITING:
+ if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.WAITING,
+ ReadState.SUSPENDING_WAIT)) {
+ continue;
+ }
+ return;
+ case PROCESSING:
+ if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.PROCESSING,
+ ReadState.SUSPENDING_PROCESS)) {
+ continue;
+ }
+ return;
+ case SUSPENDING_WAIT:
+ if (readState != ReadState.SUSPENDING_WAIT) {
+ continue;
+ } else {
+ if (getLog().isWarnEnabled()) {
+ getLog().warn(sm.getString("wsFrame.suspendRequested"));
+ }
+ }
+ return;
+ case SUSPENDING_PROCESS:
+ if (readState != ReadState.SUSPENDING_PROCESS) {
+ continue;
+ } else {
+ if (getLog().isWarnEnabled()) {
+ getLog().warn(sm.getString("wsFrame.suspendRequested"));
+ }
+ }
+ return;
+ case SUSPENDED:
+ if (readState != ReadState.SUSPENDED) {
+ continue;
+ } else {
+ if (getLog().isWarnEnabled()) {
+ getLog().warn(sm.getString("wsFrame.alreadySuspended"));
+ }
+ }
+ return;
+ case CLOSING:
+ return;
+ default:
+ throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state));
+ }
+ }
+ }
+
+ public void resume() {
+ while (true) {
+ switch (readState) {
+ case WAITING:
+ if (readState != ReadState.WAITING) {
+ continue;
+ } else {
+ if (getLog().isWarnEnabled()) {
+ getLog().warn(sm.getString("wsFrame.alreadyResumed"));
+ }
+ }
+ return;
+ case PROCESSING:
+ if (readState != ReadState.PROCESSING) {
+ continue;
+ } else {
+ if (getLog().isWarnEnabled()) {
+ getLog().warn(sm.getString("wsFrame.alreadyResumed"));
+ }
+ }
+ return;
+ case SUSPENDING_WAIT:
+ if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_WAIT,
+ ReadState.WAITING)) {
+ continue;
+ }
+ return;
+ case SUSPENDING_PROCESS:
+ if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_PROCESS,
+ ReadState.PROCESSING)) {
+ continue;
+ }
+ return;
+ case SUSPENDED:
+ if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDED,
+ ReadState.WAITING)) {
+ continue;
+ }
+ resumeProcessing();
+ return;
+ case CLOSING:
+ return;
+ default:
+ throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state));
+ }
+ }
+ }
+
+ protected boolean isSuspended() {
+ return readState.isSuspended();
+ }
+
+ protected ReadState getReadState() {
+ return readState;
+ }
+
+ protected void changeReadState(ReadState newState) {
+ READ_STATE_UPDATER.set(this, newState);
+ }
+
+ protected boolean changeReadState(ReadState oldState, ReadState newState) {
+ return READ_STATE_UPDATER.compareAndSet(this, oldState, newState);
+ }
+
+ /**
+ * This method will be invoked when the read operation is resumed.
+ * As the suspend of the read operation can be invoked at any time, when
+ * implementing this method one should consider that there might still be
+ * data remaining into the internal buffers that needs to be processed
+ * before reading again from the socket.
+ */
+ protected abstract void resumeProcessing();
+
+
+ private abstract class TerminalTransformation implements Transformation {
+
+ @Override
+ public boolean validateRsvBits(int i) {
+ // Terminal transformations don't use RSV bits and there is no next
+ // transformation so always return true.
+ return true;
+ }
+
+ @Override
+ public Extension getExtensionResponse() {
+ // Return null since terminal transformations are not extensions
+ return null;
+ }
+
+ @Override
+ public void setNext(Transformation t) {
+ // NO-OP since this is the terminal transformation
+ }
+
+ /**
+ * {@inheritDoc}
+ * <p>
+ * Anything other than a value of zero for rsv is invalid.
+ */
+ @Override
+ public boolean validateRsv(int rsv, byte opCode) {
+ return rsv == 0;
+ }
+
+ @Override
+ public void close() {
+ // NO-OP for the terminal transformations
+ }
+ }
+
+
+ /**
+ * For use by the client implementation that needs to obtain payload data
+ * without the need for unmasking.
+ */
+ private final class NoopTransformation extends TerminalTransformation {
+
+ @Override
+ public TransformationResult getMoreData(byte opCode, boolean fin, int rsv,
+ ByteBuffer dest) {
+ // opCode is ignored as the transformation is the same for all
+ // opCodes
+ // rsv is ignored as it known to be zero at this point
+ long toWrite = Math.min(payloadLength - payloadWritten, inputBuffer.remaining());
+ toWrite = Math.min(toWrite, dest.remaining());
+
+ int orgLimit = inputBuffer.limit();
+ inputBuffer.limit(inputBuffer.position() + (int) toWrite);
+ dest.put(inputBuffer);
+ inputBuffer.limit(orgLimit);
+ payloadWritten += toWrite;
+
+ if (payloadWritten == payloadLength) {
+ return TransformationResult.END_OF_FRAME;
+ } else if (inputBuffer.remaining() == 0) {
+ return TransformationResult.UNDERFLOW;
+ } else {
+ // !dest.hasRemaining()
+ return TransformationResult.OVERFLOW;
+ }
+ }
+
+
+ @Override
+ public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) {
+ // TODO Masking should move to this method
+ // NO-OP send so simply return the message unchanged.
+ return messageParts;
+ }
+ }
+
+
+ /**
+ * For use by the server implementation that needs to obtain payload data
+ * and unmask it before any further processing.
+ */
+ private final class UnmaskTransformation extends TerminalTransformation {
+
+ @Override
+ public TransformationResult getMoreData(byte opCode, boolean fin, int rsv,
+ ByteBuffer dest) {
+ // opCode is ignored as the transformation is the same for all
+ // opCodes
+ // rsv is ignored as it known to be zero at this point
+ while (payloadWritten < payloadLength && inputBuffer.remaining() > 0 &&
+ dest.hasRemaining()) {
+ byte b = (byte) ((inputBuffer.get() ^ mask[maskIndex]) & 0xFF);
+ maskIndex++;
+ if (maskIndex == 4) {
+ maskIndex = 0;
+ }
+ payloadWritten++;
+ dest.put(b);
+ }
+ if (payloadWritten == payloadLength) {
+ return TransformationResult.END_OF_FRAME;
+ } else if (inputBuffer.remaining() == 0) {
+ return TransformationResult.UNDERFLOW;
+ } else {
+ // !dest.hasRemaining()
+ return TransformationResult.OVERFLOW;
+ }
+ }
+
+ @Override
+ public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) {
+ // NO-OP send so simply return the message unchanged.
+ return messageParts;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsFrameClient.java b/src/java/nginx/unit/websocket/WsFrameClient.java
new file mode 100644
index 00000000..3174c766
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsFrameClient.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.channels.CompletionHandler;
+
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.res.StringManager;
+
+public class WsFrameClient extends WsFrameBase {
+
+ private final Log log = LogFactory.getLog(WsFrameClient.class); // must not be static
+ private static final StringManager sm = StringManager.getManager(WsFrameClient.class);
+
+ private final AsyncChannelWrapper channel;
+ private final CompletionHandler<Integer, Void> handler;
+ // Not final as it may need to be re-sized
+ private volatile ByteBuffer response;
+
+ public WsFrameClient(ByteBuffer response, AsyncChannelWrapper channel, WsSession wsSession,
+ Transformation transformation) {
+ super(wsSession, transformation);
+ this.response = response;
+ this.channel = channel;
+ this.handler = new WsFrameClientCompletionHandler();
+ }
+
+
+ void startInputProcessing() {
+ try {
+ processSocketRead();
+ } catch (IOException e) {
+ close(e);
+ }
+ }
+
+
+ private void processSocketRead() throws IOException {
+ while (true) {
+ switch (getReadState()) {
+ case WAITING:
+ if (!changeReadState(ReadState.WAITING, ReadState.PROCESSING)) {
+ continue;
+ }
+ while (response.hasRemaining()) {
+ if (isSuspended()) {
+ if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) {
+ continue;
+ }
+ // There is still data available in the response buffer
+ // Return here so that the response buffer will not be
+ // cleared and there will be no data read from the
+ // socket. Thus when the read operation is resumed first
+ // the data left in the response buffer will be consumed
+ // and then a new socket read will be performed
+ return;
+ }
+ inputBuffer.mark();
+ inputBuffer.position(inputBuffer.limit()).limit(inputBuffer.capacity());
+
+ int toCopy = Math.min(response.remaining(), inputBuffer.remaining());
+
+ // Copy remaining bytes read in HTTP phase to input buffer used by
+ // frame processing
+
+ int orgLimit = response.limit();
+ response.limit(response.position() + toCopy);
+ inputBuffer.put(response);
+ response.limit(orgLimit);
+
+ inputBuffer.limit(inputBuffer.position()).reset();
+
+ // Process the data we have
+ processInputBuffer();
+ }
+ response.clear();
+
+ // Get some more data
+ if (isOpen()) {
+ channel.read(response, null, handler);
+ } else {
+ changeReadState(ReadState.CLOSING);
+ }
+ return;
+ case SUSPENDING_WAIT:
+ if (!changeReadState(ReadState.SUSPENDING_WAIT, ReadState.SUSPENDED)) {
+ continue;
+ }
+ return;
+ default:
+ throw new IllegalStateException(
+ sm.getString("wsFrameServer.illegalReadState", getReadState()));
+ }
+ }
+ }
+
+
+ private final void close(Throwable t) {
+ changeReadState(ReadState.CLOSING);
+ CloseReason cr;
+ if (t instanceof WsIOException) {
+ cr = ((WsIOException) t).getCloseReason();
+ } else {
+ cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, t.getMessage());
+ }
+
+ try {
+ wsSession.close(cr);
+ } catch (IOException ignore) {
+ // Ignore
+ }
+ }
+
+
+ @Override
+ protected boolean isMasked() {
+ // Data is from the server so it is not masked
+ return false;
+ }
+
+
+ @Override
+ protected Log getLog() {
+ return log;
+ }
+
+ private class WsFrameClientCompletionHandler implements CompletionHandler<Integer, Void> {
+
+ @Override
+ public void completed(Integer result, Void attachment) {
+ if (result.intValue() == -1) {
+ // BZ 57762. A dropped connection will get reported as EOF
+ // rather than as an error so handle it here.
+ if (isOpen()) {
+ // No close frame was received
+ close(new EOFException());
+ }
+ // No data to process
+ return;
+ }
+ response.flip();
+ doResumeProcessing(true);
+ }
+
+ @Override
+ public void failed(Throwable exc, Void attachment) {
+ if (exc instanceof ReadBufferOverflowException) {
+ // response will be empty if this exception is thrown
+ response = ByteBuffer
+ .allocate(((ReadBufferOverflowException) exc).getMinBufferSize());
+ response.flip();
+ doResumeProcessing(false);
+ } else {
+ close(exc);
+ }
+ }
+
+ private void doResumeProcessing(boolean checkOpenOnError) {
+ while (true) {
+ switch (getReadState()) {
+ case PROCESSING:
+ if (!changeReadState(ReadState.PROCESSING, ReadState.WAITING)) {
+ continue;
+ }
+ resumeProcessing(checkOpenOnError);
+ return;
+ case SUSPENDING_PROCESS:
+ if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) {
+ continue;
+ }
+ return;
+ default:
+ throw new IllegalStateException(
+ sm.getString("wsFrame.illegalReadState", getReadState()));
+ }
+ }
+ }
+ }
+
+
+ @Override
+ protected void resumeProcessing() {
+ resumeProcessing(true);
+ }
+
+ private void resumeProcessing(boolean checkOpenOnError) {
+ try {
+ processSocketRead();
+ } catch (IOException e) {
+ if (checkOpenOnError) {
+ // Only send a close message on an IOException if the client
+ // has not yet received a close control message from the server
+ // as the IOException may be in response to the client
+ // continuing to send a message after the server sent a close
+ // control message.
+ if (isOpen()) {
+ if (log.isDebugEnabled()) {
+ log.debug(sm.getString("wsFrameClient.ioe"), e);
+ }
+ close(e);
+ }
+ } else {
+ close(e);
+ }
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsHandshakeResponse.java b/src/java/nginx/unit/websocket/WsHandshakeResponse.java
new file mode 100644
index 00000000..6e57ffd5
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsHandshakeResponse.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import javax.websocket.HandshakeResponse;
+
+import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
+
+/**
+ * Represents the response to a WebSocket handshake.
+ */
+public class WsHandshakeResponse implements HandshakeResponse {
+
+ private final Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();
+
+
+ public WsHandshakeResponse() {
+ }
+
+
+ public WsHandshakeResponse(Map<String,List<String>> headers) {
+ for (Entry<String,List<String>> entry : headers.entrySet()) {
+ if (this.headers.containsKey(entry.getKey())) {
+ this.headers.get(entry.getKey()).addAll(entry.getValue());
+ } else {
+ List<String> values = new ArrayList<>(entry.getValue());
+ this.headers.put(entry.getKey(), values);
+ }
+ }
+ }
+
+
+ @Override
+ public Map<String,List<String>> getHeaders() {
+ return headers;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsIOException.java b/src/java/nginx/unit/websocket/WsIOException.java
new file mode 100644
index 00000000..0362dc1d
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsIOException.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+
+import javax.websocket.CloseReason;
+
+/**
+ * Allows the WebSocket implementation to throw an {@link IOException} that
+ * includes a {@link CloseReason} specific to the error that can be passed back
+ * to the client.
+ */
+public class WsIOException extends IOException {
+
+ private static final long serialVersionUID = 1L;
+
+ private final CloseReason closeReason;
+
+ public WsIOException(CloseReason closeReason) {
+ this.closeReason = closeReason;
+ }
+
+ public CloseReason getCloseReason() {
+ return closeReason;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsPongMessage.java b/src/java/nginx/unit/websocket/WsPongMessage.java
new file mode 100644
index 00000000..531bcda9
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsPongMessage.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.ByteBuffer;
+
+import javax.websocket.PongMessage;
+
+public class WsPongMessage implements PongMessage {
+
+ private final ByteBuffer applicationData;
+
+
+ public WsPongMessage(ByteBuffer applicationData) {
+ byte[] dst = new byte[applicationData.limit()];
+ applicationData.get(dst);
+ this.applicationData = ByteBuffer.wrap(dst);
+ }
+
+
+ @Override
+ public ByteBuffer getApplicationData() {
+ return applicationData;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java
new file mode 100644
index 00000000..0ea20795
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.nio.ByteBuffer;
+import java.util.concurrent.Future;
+
+import javax.websocket.RemoteEndpoint;
+import javax.websocket.SendHandler;
+
+public class WsRemoteEndpointAsync extends WsRemoteEndpointBase
+ implements RemoteEndpoint.Async {
+
+ WsRemoteEndpointAsync(WsRemoteEndpointImplBase base) {
+ super(base);
+ }
+
+
+ @Override
+ public long getSendTimeout() {
+ return base.getSendTimeout();
+ }
+
+
+ @Override
+ public void setSendTimeout(long timeout) {
+ base.setSendTimeout(timeout);
+ }
+
+
+ @Override
+ public void sendText(String text, SendHandler completion) {
+ base.sendStringByCompletion(text, completion);
+ }
+
+
+ @Override
+ public Future<Void> sendText(String text) {
+ return base.sendStringByFuture(text);
+ }
+
+
+ @Override
+ public Future<Void> sendBinary(ByteBuffer data) {
+ return base.sendBytesByFuture(data);
+ }
+
+
+ @Override
+ public void sendBinary(ByteBuffer data, SendHandler completion) {
+ base.sendBytesByCompletion(data, completion);
+ }
+
+
+ @Override
+ public Future<Void> sendObject(Object obj) {
+ return base.sendObjectByFuture(obj);
+ }
+
+
+ @Override
+ public void sendObject(Object obj, SendHandler completion) {
+ base.sendObjectByCompletion(obj, completion);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java
new file mode 100644
index 00000000..21cb2040
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+
+import javax.websocket.RemoteEndpoint;
+
+public abstract class WsRemoteEndpointBase implements RemoteEndpoint {
+
+ protected final WsRemoteEndpointImplBase base;
+
+
+ WsRemoteEndpointBase(WsRemoteEndpointImplBase base) {
+ this.base = base;
+ }
+
+
+ @Override
+ public final void setBatchingAllowed(boolean batchingAllowed) throws IOException {
+ base.setBatchingAllowed(batchingAllowed);
+ }
+
+
+ @Override
+ public final boolean getBatchingAllowed() {
+ return base.getBatchingAllowed();
+ }
+
+
+ @Override
+ public final void flushBatch() throws IOException {
+ base.flushBatch();
+ }
+
+
+ @Override
+ public final void sendPing(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ base.sendPing(applicationData);
+ }
+
+
+ @Override
+ public final void sendPong(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ base.sendPong(applicationData);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java
new file mode 100644
index 00000000..2a93cc7b
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Writer;
+import java.nio.ByteBuffer;
+
+import javax.websocket.EncodeException;
+import javax.websocket.RemoteEndpoint;
+
+public class WsRemoteEndpointBasic extends WsRemoteEndpointBase
+ implements RemoteEndpoint.Basic {
+
+ WsRemoteEndpointBasic(WsRemoteEndpointImplBase base) {
+ super(base);
+ }
+
+
+ @Override
+ public void sendText(String text) throws IOException {
+ base.sendString(text);
+ }
+
+
+ @Override
+ public void sendBinary(ByteBuffer data) throws IOException {
+ base.sendBytes(data);
+ }
+
+
+ @Override
+ public void sendText(String fragment, boolean isLast) throws IOException {
+ base.sendPartialString(fragment, isLast);
+ }
+
+
+ @Override
+ public void sendBinary(ByteBuffer partialByte, boolean isLast)
+ throws IOException {
+ base.sendPartialBytes(partialByte, isLast);
+ }
+
+
+ @Override
+ public OutputStream getSendStream() throws IOException {
+ return base.getSendStream();
+ }
+
+
+ @Override
+ public Writer getSendWriter() throws IOException {
+ return base.getSendWriter();
+ }
+
+
+ @Override
+ public void sendObject(Object o) throws IOException, EncodeException {
+ base.sendObject(o);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java
new file mode 100644
index 00000000..776124fd
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java
@@ -0,0 +1,1234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Writer;
+import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.CharsetEncoder;
+import java.nio.charset.CoderResult;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.Future;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.EncodeException;
+import javax.websocket.Encoder;
+import javax.websocket.EndpointConfig;
+import javax.websocket.RemoteEndpoint;
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.buf.Utf8Encoder;
+import org.apache.tomcat.util.res.StringManager;
+
+import nginx.unit.Request;
+
+public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
+
+ private static final StringManager sm =
+ StringManager.getManager(WsRemoteEndpointImplBase.class);
+
+ protected static final SendResult SENDRESULT_OK = new SendResult();
+
+ private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static
+
+ private final StateMachine stateMachine = new StateMachine();
+
+ private final IntermediateMessageHandler intermediateMessageHandler =
+ new IntermediateMessageHandler(this);
+
+ private Transformation transformation = null;
+ private final Semaphore messagePartInProgress = new Semaphore(1);
+ private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>();
+ private final Object messagePartLock = new Object();
+
+ // State
+ private volatile boolean closed = false;
+ private boolean fragmented = false;
+ private boolean nextFragmented = false;
+ private boolean text = false;
+ private boolean nextText = false;
+
+ // Max size of WebSocket header is 14 bytes
+ private final ByteBuffer headerBuffer = ByteBuffer.allocate(14);
+ private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final CharsetEncoder encoder = new Utf8Encoder();
+ private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final AtomicBoolean batchingAllowed = new AtomicBoolean(false);
+ private volatile long sendTimeout = -1;
+ private WsSession wsSession;
+ private List<EncoderEntry> encoderEntries = new ArrayList<>();
+
+ private Request request;
+
+
+ protected void setTransformation(Transformation transformation) {
+ this.transformation = transformation;
+ }
+
+
+ public long getSendTimeout() {
+ return sendTimeout;
+ }
+
+
+ public void setSendTimeout(long timeout) {
+ this.sendTimeout = timeout;
+ }
+
+
+ @Override
+ public void setBatchingAllowed(boolean batchingAllowed) throws IOException {
+ boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
+
+ if (oldValue && !batchingAllowed) {
+ flushBatch();
+ }
+ }
+
+
+ @Override
+ public boolean getBatchingAllowed() {
+ return batchingAllowed.get();
+ }
+
+
+ @Override
+ public void flushBatch() throws IOException {
+ sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true);
+ }
+
+
+ public void sendBytes(ByteBuffer data) throws IOException {
+ if (data == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.binaryStart();
+ sendMessageBlock(Constants.OPCODE_BINARY, data, true);
+ stateMachine.complete(true);
+ }
+
+
+ public Future<Void> sendBytesByFuture(ByteBuffer data) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendBytesByCompletion(data, f2sh);
+ return f2sh;
+ }
+
+
+ public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) {
+ if (data == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (handler == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+ StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine);
+ stateMachine.binaryStart();
+ startMessage(Constants.OPCODE_BINARY, data, true, sush);
+ }
+
+
+ public void sendPartialBytes(ByteBuffer partialByte, boolean last)
+ throws IOException {
+ if (partialByte == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.binaryPartialStart();
+ sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last);
+ stateMachine.complete(last);
+ }
+
+
+ @Override
+ public void sendPing(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ if (applicationData.remaining() > 125) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
+ }
+ sendMessageBlock(Constants.OPCODE_PING, applicationData, true);
+ }
+
+
+ @Override
+ public void sendPong(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ if (applicationData.remaining() > 125) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
+ }
+ sendMessageBlock(Constants.OPCODE_PONG, applicationData, true);
+ }
+
+
+ public void sendString(String text) throws IOException {
+ if (text == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.textStart();
+ sendMessageBlock(CharBuffer.wrap(text), true);
+ }
+
+
+ public Future<Void> sendStringByFuture(String text) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendStringByCompletion(text, f2sh);
+ return f2sh;
+ }
+
+
+ public void sendStringByCompletion(String text, SendHandler handler) {
+ if (text == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (handler == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+ stateMachine.textStart();
+ TextMessageSendHandler tmsh = new TextMessageSendHandler(handler,
+ CharBuffer.wrap(text), true, encoder, encoderBuffer, this);
+ tmsh.write();
+ // TextMessageSendHandler will update stateMachine when it completes
+ }
+
+
+ public void sendPartialString(String fragment, boolean isLast)
+ throws IOException {
+ if (fragment == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.textPartialStart();
+ sendMessageBlock(CharBuffer.wrap(fragment), isLast);
+ }
+
+
+ public OutputStream getSendStream() {
+ stateMachine.streamStart();
+ return new WsOutputStream(this);
+ }
+
+
+ public Writer getSendWriter() {
+ stateMachine.writeStart();
+ return new WsWriter(this);
+ }
+
+
+ void sendMessageBlock(CharBuffer part, boolean last) throws IOException {
+ long timeoutExpiry = getTimeoutExpiry();
+ boolean isDone = false;
+ while (!isDone) {
+ encoderBuffer.clear();
+ CoderResult cr = encoder.encode(part, encoderBuffer, true);
+ if (cr.isError()) {
+ throw new IllegalArgumentException(cr.toString());
+ }
+ isDone = !cr.isOverflow();
+ encoderBuffer.flip();
+ sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry);
+ }
+ stateMachine.complete(last);
+ }
+
+
+ void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last)
+ throws IOException {
+ sendMessageBlock(opCode, payload, last, getTimeoutExpiry());
+ }
+
+
+ private long getTimeoutExpiry() {
+ // Get the timeout before we send the message. The message may
+ // trigger a session close and depending on timing the client
+ // session may close before we can read the timeout.
+ long timeout = getBlockingSendTimeout();
+ if (timeout < 0) {
+ return Long.MAX_VALUE;
+ } else {
+ return System.currentTimeMillis() + timeout;
+ }
+ }
+
+ private byte currentOpCode = Constants.OPCODE_CONTINUATION;
+
+ private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last,
+ long timeoutExpiry) throws IOException {
+ wsSession.updateLastActive();
+
+ if (opCode == currentOpCode) {
+ opCode = Constants.OPCODE_CONTINUATION;
+ }
+
+ request.sendWsFrame(payload, opCode, last, timeoutExpiry);
+
+ if (!last && opCode != Constants.OPCODE_CONTINUATION) {
+ currentOpCode = opCode;
+ }
+
+ if (last && opCode == Constants.OPCODE_CONTINUATION) {
+ currentOpCode = Constants.OPCODE_CONTINUATION;
+ }
+ }
+
+
+ void startMessage(byte opCode, ByteBuffer payload, boolean last,
+ SendHandler handler) {
+
+ wsSession.updateLastActive();
+
+ List<MessagePart> messageParts = new ArrayList<>();
+ messageParts.add(new MessagePart(last, 0, opCode, payload,
+ intermediateMessageHandler,
+ new EndMessageHandler(this, handler), -1));
+
+ messageParts = transformation.sendMessagePart(messageParts);
+
+ // Some extensions/transformations may buffer messages so it is possible
+ // that no message parts will be returned. If this is the case the
+ // trigger the supplied SendHandler
+ if (messageParts.size() == 0) {
+ handler.onResult(new SendResult());
+ return;
+ }
+
+ MessagePart mp = messageParts.remove(0);
+
+ boolean doWrite = false;
+ synchronized (messagePartLock) {
+ if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) {
+ // Should not happen. To late to send batched messages now since
+ // the session has been closed. Complain loudly.
+ log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed"));
+ }
+ if (messagePartInProgress.tryAcquire()) {
+ doWrite = true;
+ } else {
+ // When a control message is sent while another message is being
+ // sent, the control message is queued. Chances are the
+ // subsequent data message part will end up queued while the
+ // control message is sent. The logic in this class (state
+ // machine, EndMessageHandler, TextMessageSendHandler) ensures
+ // that there will only ever be one data message part in the
+ // queue. There could be multiple control messages in the queue.
+
+ // Add it to the queue
+ messagePartQueue.add(mp);
+ }
+ // Add any remaining messages to the queue
+ messagePartQueue.addAll(messageParts);
+ }
+ if (doWrite) {
+ // Actual write has to be outside sync block to avoid possible
+ // deadlock between messagePartLock and writeLock in
+ // o.a.coyote.http11.upgrade.AbstractServletOutputStream
+ writeMessagePart(mp);
+ }
+ }
+
+
+ void endMessage(SendHandler handler, SendResult result) {
+ boolean doWrite = false;
+ MessagePart mpNext = null;
+ synchronized (messagePartLock) {
+
+ fragmented = nextFragmented;
+ text = nextText;
+
+ mpNext = messagePartQueue.poll();
+ if (mpNext == null) {
+ messagePartInProgress.release();
+ } else if (!closed){
+ // Session may have been closed unexpectedly in the middle of
+ // sending a fragmented message closing the endpoint. If this
+ // happens, clearly there is no point trying to send the rest of
+ // the message.
+ doWrite = true;
+ }
+ }
+ if (doWrite) {
+ // Actual write has to be outside sync block to avoid possible
+ // deadlock between messagePartLock and writeLock in
+ // o.a.coyote.http11.upgrade.AbstractServletOutputStream
+ writeMessagePart(mpNext);
+ }
+
+ wsSession.updateLastActive();
+
+ // Some handlers, such as the IntermediateMessageHandler, do not have a
+ // nested handler so handler may be null.
+ if (handler != null) {
+ handler.onResult(result);
+ }
+ }
+
+
+ void writeMessagePart(MessagePart mp) {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closed"));
+ }
+
+ if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) {
+ nextFragmented = fragmented;
+ nextText = text;
+ outputBuffer.flip();
+ SendHandler flushHandler = new OutputBufferFlushSendHandler(
+ outputBuffer, mp.getEndHandler());
+ doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer);
+ return;
+ }
+
+ // Control messages may be sent in the middle of fragmented message
+ // so they have no effect on the fragmented or text flags
+ boolean first;
+ if (Util.isControl(mp.getOpCode())) {
+ nextFragmented = fragmented;
+ nextText = text;
+ if (mp.getOpCode() == Constants.OPCODE_CLOSE) {
+ closed = true;
+ }
+ first = true;
+ } else {
+ boolean isText = Util.isText(mp.getOpCode());
+
+ if (fragmented) {
+ // Currently fragmented
+ if (text != isText) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.changeType"));
+ }
+ nextText = text;
+ nextFragmented = !mp.isFin();
+ first = false;
+ } else {
+ // Wasn't fragmented. Might be now
+ if (mp.isFin()) {
+ nextFragmented = false;
+ } else {
+ nextFragmented = true;
+ nextText = isText;
+ }
+ first = true;
+ }
+ }
+
+ byte[] mask;
+
+ if (isMasked()) {
+ mask = Util.generateMask();
+ } else {
+ mask = null;
+ }
+
+ headerBuffer.clear();
+ writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(),
+ isMasked(), mp.getPayload(), mask, first);
+ headerBuffer.flip();
+
+ if (getBatchingAllowed() || isMasked()) {
+ // Need to write via output buffer
+ OutputBufferSendHandler obsh = new OutputBufferSendHandler(
+ mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
+ headerBuffer, mp.getPayload(), mask,
+ outputBuffer, !getBatchingAllowed(), this);
+ obsh.write();
+ } else {
+ // Can write directly
+ doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
+ headerBuffer, mp.getPayload());
+ }
+ }
+
+
+ private long getBlockingSendTimeout() {
+ Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY);
+ Long userTimeout = null;
+ if (obj instanceof Long) {
+ userTimeout = (Long) obj;
+ }
+ if (userTimeout == null) {
+ return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT;
+ } else {
+ return userTimeout.longValue();
+ }
+ }
+
+
+ /**
+ * Wraps the user provided handler so that the end point is notified when
+ * the message is complete.
+ */
+ private static class EndMessageHandler implements SendHandler {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final SendHandler handler;
+
+ public EndMessageHandler(WsRemoteEndpointImplBase endpoint,
+ SendHandler handler) {
+ this.endpoint = endpoint;
+ this.handler = handler;
+ }
+
+
+ @Override
+ public void onResult(SendResult result) {
+ endpoint.endMessage(handler, result);
+ }
+ }
+
+
+ /**
+ * If a transformation needs to split a {@link MessagePart} into multiple
+ * {@link MessagePart}s, it uses this handler as the end handler for each of
+ * the additional {@link MessagePart}s. This handler notifies this this
+ * class that the {@link MessagePart} has been processed and that the next
+ * {@link MessagePart} in the queue should be started. The final
+ * {@link MessagePart} will use the {@link EndMessageHandler} provided with
+ * the original {@link MessagePart}.
+ */
+ private static class IntermediateMessageHandler implements SendHandler {
+
+ private final WsRemoteEndpointImplBase endpoint;
+
+ public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+
+ @Override
+ public void onResult(SendResult result) {
+ endpoint.endMessage(null, result);
+ }
+ }
+
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public void sendObject(Object obj) throws IOException, EncodeException {
+ if (obj == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ /*
+ * Note that the implementation will convert primitives and their object
+ * equivalents by default but that users are free to specify their own
+ * encoders and decoders for this if they wish.
+ */
+ Encoder encoder = findEncoder(obj);
+ if (encoder == null && Util.isPrimitive(obj.getClass())) {
+ String msg = obj.toString();
+ sendString(msg);
+ return;
+ }
+ if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
+ ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
+ sendBytes(msg);
+ return;
+ }
+
+ if (encoder instanceof Encoder.Text) {
+ String msg = ((Encoder.Text) encoder).encode(obj);
+ sendString(msg);
+ } else if (encoder instanceof Encoder.TextStream) {
+ try (Writer w = getSendWriter()) {
+ ((Encoder.TextStream) encoder).encode(obj, w);
+ }
+ } else if (encoder instanceof Encoder.Binary) {
+ ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
+ sendBytes(msg);
+ } else if (encoder instanceof Encoder.BinaryStream) {
+ try (OutputStream os = getSendStream()) {
+ ((Encoder.BinaryStream) encoder).encode(obj, os);
+ }
+ } else {
+ throw new EncodeException(obj, sm.getString(
+ "wsRemoteEndpoint.noEncoder", obj.getClass()));
+ }
+ }
+
+
+ public Future<Void> sendObjectByFuture(Object obj) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendObjectByCompletion(obj, f2sh);
+ return f2sh;
+ }
+
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public void sendObjectByCompletion(Object obj, SendHandler completion) {
+
+ if (obj == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (completion == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+
+ /*
+ * Note that the implementation will convert primitives and their object
+ * equivalents by default but that users are free to specify their own
+ * encoders and decoders for this if they wish.
+ */
+ Encoder encoder = findEncoder(obj);
+ if (encoder == null && Util.isPrimitive(obj.getClass())) {
+ String msg = obj.toString();
+ sendStringByCompletion(msg, completion);
+ return;
+ }
+ if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
+ ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
+ sendBytesByCompletion(msg, completion);
+ return;
+ }
+
+ try {
+ if (encoder instanceof Encoder.Text) {
+ String msg = ((Encoder.Text) encoder).encode(obj);
+ sendStringByCompletion(msg, completion);
+ } else if (encoder instanceof Encoder.TextStream) {
+ try (Writer w = getSendWriter()) {
+ ((Encoder.TextStream) encoder).encode(obj, w);
+ }
+ completion.onResult(new SendResult());
+ } else if (encoder instanceof Encoder.Binary) {
+ ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
+ sendBytesByCompletion(msg, completion);
+ } else if (encoder instanceof Encoder.BinaryStream) {
+ try (OutputStream os = getSendStream()) {
+ ((Encoder.BinaryStream) encoder).encode(obj, os);
+ }
+ completion.onResult(new SendResult());
+ } else {
+ throw new EncodeException(obj, sm.getString(
+ "wsRemoteEndpoint.noEncoder", obj.getClass()));
+ }
+ } catch (Exception e) {
+ SendResult sr = new SendResult(e);
+ completion.onResult(sr);
+ }
+ }
+
+
+ protected void setSession(WsSession wsSession) {
+ this.wsSession = wsSession;
+ }
+
+
+ protected void setRequest(Request request) {
+ this.request = request;
+ }
+
+ protected void setEncoders(EndpointConfig endpointConfig)
+ throws DeploymentException {
+ encoderEntries.clear();
+ for (Class<? extends Encoder> encoderClazz :
+ endpointConfig.getEncoders()) {
+ Encoder instance;
+ try {
+ instance = encoderClazz.getConstructor().newInstance();
+ instance.init(endpointConfig);
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(
+ sm.getString("wsRemoteEndpoint.invalidEncoder",
+ encoderClazz.getName()), e);
+ }
+ EncoderEntry entry = new EncoderEntry(
+ Util.getEncoderType(encoderClazz), instance);
+ encoderEntries.add(entry);
+ }
+ }
+
+
+ private Encoder findEncoder(Object obj) {
+ for (EncoderEntry entry : encoderEntries) {
+ if (entry.getClazz().isAssignableFrom(obj.getClass())) {
+ return entry.getEncoder();
+ }
+ }
+ return null;
+ }
+
+
+ public final void close() {
+ for (EncoderEntry entry : encoderEntries) {
+ entry.getEncoder().destroy();
+ }
+
+ request.closeWs();
+ }
+
+
+ protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
+ ByteBuffer... data);
+ protected abstract boolean isMasked();
+ protected abstract void doClose();
+
+ private static void writeHeader(ByteBuffer headerBuffer, boolean fin,
+ int rsv, byte opCode, boolean masked, ByteBuffer payload,
+ byte[] mask, boolean first) {
+
+ byte b = 0;
+
+ if (fin) {
+ // Set the fin bit
+ b -= 128;
+ }
+
+ b += (rsv << 4);
+
+ if (first) {
+ // This is the first fragment of this message
+ b += opCode;
+ }
+ // If not the first fragment, it is a continuation with opCode of zero
+
+ headerBuffer.put(b);
+
+ if (masked) {
+ b = (byte) 0x80;
+ } else {
+ b = 0;
+ }
+
+ // Next write the mask && length length
+ if (payload.limit() < 126) {
+ headerBuffer.put((byte) (payload.limit() | b));
+ } else if (payload.limit() < 65536) {
+ headerBuffer.put((byte) (126 | b));
+ headerBuffer.put((byte) (payload.limit() >>> 8));
+ headerBuffer.put((byte) (payload.limit() & 0xFF));
+ } else {
+ // Will never be more than 2^31-1
+ headerBuffer.put((byte) (127 | b));
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) (payload.limit() >>> 24));
+ headerBuffer.put((byte) (payload.limit() >>> 16));
+ headerBuffer.put((byte) (payload.limit() >>> 8));
+ headerBuffer.put((byte) (payload.limit() & 0xFF));
+ }
+ if (masked) {
+ headerBuffer.put(mask[0]);
+ headerBuffer.put(mask[1]);
+ headerBuffer.put(mask[2]);
+ headerBuffer.put(mask[3]);
+ }
+ }
+
+
+ private class TextMessageSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final CharBuffer message;
+ private final boolean isLast;
+ private final CharsetEncoder encoder;
+ private final ByteBuffer buffer;
+ private final WsRemoteEndpointImplBase endpoint;
+ private volatile boolean isDone = false;
+
+ public TextMessageSendHandler(SendHandler handler, CharBuffer message,
+ boolean isLast, CharsetEncoder encoder,
+ ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) {
+ this.handler = handler;
+ this.message = message;
+ this.isLast = isLast;
+ this.encoder = encoder.reset();
+ this.buffer = encoderBuffer;
+ this.endpoint = endpoint;
+ }
+
+ public void write() {
+ buffer.clear();
+ CoderResult cr = encoder.encode(message, buffer, true);
+ if (cr.isError()) {
+ throw new IllegalArgumentException(cr.toString());
+ }
+ isDone = !cr.isOverflow();
+ buffer.flip();
+ endpoint.startMessage(Constants.OPCODE_TEXT, buffer,
+ isDone && isLast, this);
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (isDone) {
+ endpoint.stateMachine.complete(isLast);
+ handler.onResult(result);
+ } else if(!result.isOK()) {
+ handler.onResult(result);
+ } else if (closed){
+ SendResult sr = new SendResult(new IOException(
+ sm.getString("wsRemoteEndpoint.closedDuringMessage")));
+ handler.onResult(sr);
+ } else {
+ write();
+ }
+ }
+ }
+
+
+ /**
+ * Used to write data to the output buffer, flushing the buffer if it fills
+ * up.
+ */
+ private static class OutputBufferSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final long blockingWriteTimeoutExpiry;
+ private final ByteBuffer headerBuffer;
+ private final ByteBuffer payload;
+ private final byte[] mask;
+ private final ByteBuffer outputBuffer;
+ private final boolean flushRequired;
+ private final WsRemoteEndpointImplBase endpoint;
+ private int maskIndex = 0;
+
+ public OutputBufferSendHandler(SendHandler completion,
+ long blockingWriteTimeoutExpiry,
+ ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask,
+ ByteBuffer outputBuffer, boolean flushRequired,
+ WsRemoteEndpointImplBase endpoint) {
+ this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry;
+ this.handler = completion;
+ this.headerBuffer = headerBuffer;
+ this.payload = payload;
+ this.mask = mask;
+ this.outputBuffer = outputBuffer;
+ this.flushRequired = flushRequired;
+ this.endpoint = endpoint;
+ }
+
+ public void write() {
+ // Write the header
+ while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) {
+ outputBuffer.put(headerBuffer.get());
+ }
+ if (headerBuffer.hasRemaining()) {
+ // Still more headers to write, need to flush
+ outputBuffer.flip();
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ return;
+ }
+
+ // Write the payload
+ int payloadLeft = payload.remaining();
+ int payloadLimit = payload.limit();
+ int outputSpace = outputBuffer.remaining();
+ int toWrite = payloadLeft;
+
+ if (payloadLeft > outputSpace) {
+ toWrite = outputSpace;
+ // Temporarily reduce the limit
+ payload.limit(payload.position() + toWrite);
+ }
+
+ if (mask == null) {
+ // Use a bulk copy
+ outputBuffer.put(payload);
+ } else {
+ for (int i = 0; i < toWrite; i++) {
+ outputBuffer.put(
+ (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF)));
+ if (maskIndex > 3) {
+ maskIndex = 0;
+ }
+ }
+ }
+
+ if (payloadLeft > outputSpace) {
+ // Restore the original limit
+ payload.limit(payloadLimit);
+ // Still more data to write, need to flush
+ outputBuffer.flip();
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ return;
+ }
+
+ if (flushRequired) {
+ outputBuffer.flip();
+ if (outputBuffer.remaining() == 0) {
+ handler.onResult(SENDRESULT_OK);
+ } else {
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ }
+ } else {
+ handler.onResult(SENDRESULT_OK);
+ }
+ }
+
+ // ------------------------------------------------- SendHandler methods
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ if (outputBuffer.hasRemaining()) {
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ } else {
+ outputBuffer.clear();
+ write();
+ }
+ } else {
+ handler.onResult(result);
+ }
+ }
+ }
+
+
+ /**
+ * Ensures that the output buffer is cleared after it has been flushed.
+ */
+ private static class OutputBufferFlushSendHandler implements SendHandler {
+
+ private final ByteBuffer outputBuffer;
+ private final SendHandler handler;
+
+ public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) {
+ this.outputBuffer = outputBuffer;
+ this.handler = handler;
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ outputBuffer.clear();
+ }
+ handler.onResult(result);
+ }
+ }
+
+
+ private static class WsOutputStream extends OutputStream {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Object closeLock = new Object();
+ private volatile boolean closed = false;
+ private volatile boolean used = false;
+
+ public WsOutputStream(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ buffer.put((byte) b);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+ if (len == 0) {
+ return;
+ }
+ if ((off < 0) || (off > b.length) || (len < 0) ||
+ ((off + len) > b.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ int remaining = buffer.remaining();
+ int written = 0;
+
+ while (remaining < len - written) {
+ buffer.put(b, off + written, remaining);
+ written += remaining;
+ flush();
+ remaining = buffer.remaining();
+ }
+ buffer.put(b, off + written, len - written);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+
+ // Optimisation. If there is no data to flush then do not send an
+ // empty message.
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
+ doWrite(false);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (closeLock) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ }
+
+ doWrite(true);
+ }
+
+ private void doWrite(boolean last) throws IOException {
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
+ buffer.flip();
+ endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last);
+ }
+ endpoint.stateMachine.complete(last);
+ buffer.clear();
+ }
+ }
+
+
+ private static class WsWriter extends Writer {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Object closeLock = new Object();
+ private volatile boolean closed = false;
+ private volatile boolean used = false;
+
+ public WsWriter(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ @Override
+ public void write(char[] cbuf, int off, int len) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedWriter"));
+ }
+ if (len == 0) {
+ return;
+ }
+ if ((off < 0) || (off > cbuf.length) || (len < 0) ||
+ ((off + len) > cbuf.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ int remaining = buffer.remaining();
+ int written = 0;
+
+ while (remaining < len - written) {
+ buffer.put(cbuf, off + written, remaining);
+ written += remaining;
+ flush();
+ remaining = buffer.remaining();
+ }
+ buffer.put(cbuf, off + written, len - written);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedWriter"));
+ }
+
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
+ doWrite(false);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (closeLock) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ }
+
+ doWrite(true);
+ }
+
+ private void doWrite(boolean last) throws IOException {
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
+ buffer.flip();
+ endpoint.sendMessageBlock(buffer, last);
+ buffer.clear();
+ } else {
+ endpoint.stateMachine.complete(last);
+ }
+ }
+ }
+
+
+ private static class EncoderEntry {
+
+ private final Class<?> clazz;
+ private final Encoder encoder;
+
+ public EncoderEntry(Class<?> clazz, Encoder encoder) {
+ this.clazz = clazz;
+ this.encoder = encoder;
+ }
+
+ public Class<?> getClazz() {
+ return clazz;
+ }
+
+ public Encoder getEncoder() {
+ return encoder;
+ }
+ }
+
+
+ private enum State {
+ OPEN,
+ STREAM_WRITING,
+ WRITER_WRITING,
+ BINARY_PARTIAL_WRITING,
+ BINARY_PARTIAL_READY,
+ BINARY_FULL_WRITING,
+ TEXT_PARTIAL_WRITING,
+ TEXT_PARTIAL_READY,
+ TEXT_FULL_WRITING
+ }
+
+
+ private static class StateMachine {
+ private State state = State.OPEN;
+
+ public synchronized void streamStart() {
+ checkState(State.OPEN);
+ state = State.STREAM_WRITING;
+ }
+
+ public synchronized void writeStart() {
+ checkState(State.OPEN);
+ state = State.WRITER_WRITING;
+ }
+
+ public synchronized void binaryPartialStart() {
+ checkState(State.OPEN, State.BINARY_PARTIAL_READY);
+ state = State.BINARY_PARTIAL_WRITING;
+ }
+
+ public synchronized void binaryStart() {
+ checkState(State.OPEN);
+ state = State.BINARY_FULL_WRITING;
+ }
+
+ public synchronized void textPartialStart() {
+ checkState(State.OPEN, State.TEXT_PARTIAL_READY);
+ state = State.TEXT_PARTIAL_WRITING;
+ }
+
+ public synchronized void textStart() {
+ checkState(State.OPEN);
+ state = State.TEXT_FULL_WRITING;
+ }
+
+ public synchronized void complete(boolean last) {
+ if (last) {
+ checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING,
+ State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING,
+ State.STREAM_WRITING, State.WRITER_WRITING);
+ state = State.OPEN;
+ } else {
+ checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING,
+ State.STREAM_WRITING, State.WRITER_WRITING);
+ if (state == State.TEXT_PARTIAL_WRITING) {
+ state = State.TEXT_PARTIAL_READY;
+ } else if (state == State.BINARY_PARTIAL_WRITING){
+ state = State.BINARY_PARTIAL_READY;
+ } else if (state == State.WRITER_WRITING) {
+ // NO-OP. Leave state as is.
+ } else if (state == State.STREAM_WRITING) {
+ // NO-OP. Leave state as is.
+ } else {
+ // Should never happen
+ // The if ... else ... blocks above should cover all states
+ // permitted by the preceding checkState() call
+ throw new IllegalStateException(
+ "BUG: This code should never be called");
+ }
+ }
+ }
+
+ private void checkState(State... required) {
+ for (State state : required) {
+ if (this.state == state) {
+ return;
+ }
+ }
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.wrongState", this.state));
+ }
+ }
+
+
+ private static class StateUpdateSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final StateMachine stateMachine;
+
+ public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) {
+ this.handler = handler;
+ this.stateMachine = stateMachine;
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ stateMachine.complete(true);
+ }
+ handler.onResult(result);
+ }
+ }
+
+
+ private static class BlockingSendHandler implements SendHandler {
+
+ private SendResult sendResult = null;
+
+ @Override
+ public void onResult(SendResult result) {
+ sendResult = result;
+ }
+
+ public SendResult getSendResult() {
+ return sendResult;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java
new file mode 100644
index 00000000..70b66789
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
+public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase {
+
+ private final AsyncChannelWrapper channel;
+
+ public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) {
+ this.channel = channel;
+ }
+
+
+ @Override
+ protected boolean isMasked() {
+ return true;
+ }
+
+
+ @Override
+ protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
+ ByteBuffer... data) {
+ long timeout;
+ for (ByteBuffer byteBuffer : data) {
+ if (blockingWriteTimeoutExpiry == -1) {
+ timeout = getSendTimeout();
+ if (timeout < 1) {
+ timeout = Long.MAX_VALUE;
+ }
+ } else {
+ timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis();
+ if (timeout < 0) {
+ SendResult sr = new SendResult(new IOException("Blocking write timeout"));
+ handler.onResult(sr);
+ }
+ }
+
+ try {
+ channel.write(byteBuffer).get(timeout, TimeUnit.MILLISECONDS);
+ } catch (InterruptedException | ExecutionException | TimeoutException e) {
+ handler.onResult(new SendResult(e));
+ return;
+ }
+ }
+ handler.onResult(SENDRESULT_OK);
+ }
+
+ @Override
+ protected void doClose() {
+ channel.close();
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsSession.java b/src/java/nginx/unit/websocket/WsSession.java
new file mode 100644
index 00000000..b654eb37
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsSession.java
@@ -0,0 +1,1070 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.net.URI;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.CharBuffer;
+import java.nio.channels.WritePendingException;
+import java.nio.charset.CharsetDecoder;
+import java.nio.charset.CoderResult;
+import java.nio.charset.CodingErrorAction;
+import java.nio.charset.StandardCharsets;
+import java.security.Principal;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCode;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Extension;
+import javax.websocket.MessageHandler;
+import javax.websocket.MessageHandler.Partial;
+import javax.websocket.MessageHandler.Whole;
+import javax.websocket.PongMessage;
+import javax.websocket.RemoteEndpoint;
+import javax.websocket.SendResult;
+import javax.websocket.Session;
+import javax.websocket.WebSocketContainer;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.InstanceManager;
+import org.apache.tomcat.InstanceManagerBindings;
+import org.apache.tomcat.util.ExceptionUtils;
+import org.apache.tomcat.util.buf.Utf8Decoder;
+import org.apache.tomcat.util.res.StringManager;
+
+import nginx.unit.Request;
+
+public class WsSession implements Session {
+
+ // An ellipsis is a single character that looks like three periods in a row
+ // and is used to indicate a continuation.
+ private static final byte[] ELLIPSIS_BYTES = "\u2026".getBytes(StandardCharsets.UTF_8);
+ // An ellipsis is three bytes in UTF-8
+ private static final int ELLIPSIS_BYTES_LEN = ELLIPSIS_BYTES.length;
+
+ private static final StringManager sm = StringManager.getManager(WsSession.class);
+ private static AtomicLong ids = new AtomicLong(0);
+
+ private final Log log = LogFactory.getLog(WsSession.class); // must not be static
+
+ private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder().
+ onMalformedInput(CodingErrorAction.REPORT).
+ onUnmappableCharacter(CodingErrorAction.REPORT);
+
+ private final Endpoint localEndpoint;
+ private final WsRemoteEndpointImplBase wsRemoteEndpoint;
+ private final RemoteEndpoint.Async remoteEndpointAsync;
+ private final RemoteEndpoint.Basic remoteEndpointBasic;
+ private final ClassLoader applicationClassLoader;
+ private final WsWebSocketContainer webSocketContainer;
+ private final URI requestUri;
+ private final Map<String, List<String>> requestParameterMap;
+ private final String queryString;
+ private final Principal userPrincipal;
+ private final EndpointConfig endpointConfig;
+
+ private final List<Extension> negotiatedExtensions;
+ private final String subProtocol;
+ private final Map<String, String> pathParameters;
+ private final boolean secure;
+ private final String httpSessionId;
+ private final String id;
+
+ // Expected to handle message types of <String> only
+ private volatile MessageHandler textMessageHandler = null;
+ // Expected to handle message types of <ByteBuffer> only
+ private volatile MessageHandler binaryMessageHandler = null;
+ private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null;
+ private volatile State state = State.OPEN;
+ private final Object stateLock = new Object();
+ private final Map<String, Object> userProperties = new ConcurrentHashMap<>();
+ private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private volatile long maxIdleTimeout = 0;
+ private volatile long lastActive = System.currentTimeMillis();
+ private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>();
+
+ private CharBuffer messageBufferText;
+ private ByteBuffer binaryBuffer;
+ private byte startOpCode = Constants.OPCODE_CONTINUATION;
+
+ /**
+ * Creates a new WebSocket session for communication between the two
+ * provided end points. The result of {@link Thread#getContextClassLoader()}
+ * at the time this constructor is called will be used when calling
+ * {@link Endpoint#onClose(Session, CloseReason)}.
+ *
+ * @param localEndpoint The end point managed by this code
+ * @param wsRemoteEndpoint The other / remote endpoint
+ * @param wsWebSocketContainer The container that created this session
+ * @param requestUri The URI used to connect to this endpoint or
+ * <code>null</code> is this is a client session
+ * @param requestParameterMap The parameters associated with the request
+ * that initiated this session or
+ * <code>null</code> if this is a client session
+ * @param queryString The query string associated with the request
+ * that initiated this session or
+ * <code>null</code> if this is a client session
+ * @param userPrincipal The principal associated with the request
+ * that initiated this session or
+ * <code>null</code> if this is a client session
+ * @param httpSessionId The HTTP session ID associated with the
+ * request that initiated this session or
+ * <code>null</code> if this is a client session
+ * @param negotiatedExtensions The agreed extensions to use for this session
+ * @param subProtocol The agreed subprotocol to use for this
+ * session
+ * @param pathParameters The path parameters associated with the
+ * request that initiated this session or
+ * <code>null</code> if this is a client session
+ * @param secure Was this session initiated over a secure
+ * connection?
+ * @param endpointConfig The configuration information for the
+ * endpoint
+ * @throws DeploymentException if an invalid encode is specified
+ */
+ public WsSession(Endpoint localEndpoint,
+ WsRemoteEndpointImplBase wsRemoteEndpoint,
+ WsWebSocketContainer wsWebSocketContainer,
+ URI requestUri, Map<String, List<String>> requestParameterMap,
+ String queryString, Principal userPrincipal, String httpSessionId,
+ List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters,
+ boolean secure, EndpointConfig endpointConfig,
+ Request request) throws DeploymentException {
+ this.localEndpoint = localEndpoint;
+ this.wsRemoteEndpoint = wsRemoteEndpoint;
+ this.wsRemoteEndpoint.setSession(this);
+ this.wsRemoteEndpoint.setRequest(request);
+
+ request.setWsSession(this);
+
+ this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint);
+ this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint);
+ this.webSocketContainer = wsWebSocketContainer;
+ applicationClassLoader = Thread.currentThread().getContextClassLoader();
+ wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout());
+ this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize();
+ this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize();
+ this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout();
+ this.requestUri = requestUri;
+ if (requestParameterMap == null) {
+ this.requestParameterMap = Collections.emptyMap();
+ } else {
+ this.requestParameterMap = requestParameterMap;
+ }
+ this.queryString = queryString;
+ this.userPrincipal = userPrincipal;
+ this.httpSessionId = httpSessionId;
+ this.negotiatedExtensions = negotiatedExtensions;
+ if (subProtocol == null) {
+ this.subProtocol = "";
+ } else {
+ this.subProtocol = subProtocol;
+ }
+ this.pathParameters = pathParameters;
+ this.secure = secure;
+ this.wsRemoteEndpoint.setEncoders(endpointConfig);
+ this.endpointConfig = endpointConfig;
+
+ this.userProperties.putAll(endpointConfig.getUserProperties());
+ this.id = Long.toHexString(ids.getAndIncrement());
+
+ InstanceManager instanceManager = webSocketContainer.getInstanceManager();
+ if (instanceManager == null) {
+ instanceManager = InstanceManagerBindings.get(applicationClassLoader);
+ }
+ if (instanceManager != null) {
+ try {
+ instanceManager.newInstance(localEndpoint);
+ } catch (Exception e) {
+ throw new DeploymentException(sm.getString("wsSession.instanceNew"), e);
+ }
+ }
+
+ if (log.isDebugEnabled()) {
+ log.debug(sm.getString("wsSession.created", id));
+ }
+
+ messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize);
+ }
+
+ public static String wsSession_test() {
+ return sm.getString("wsSession.instanceNew");
+ }
+
+
+ @Override
+ public WebSocketContainer getContainer() {
+ checkState();
+ return webSocketContainer;
+ }
+
+
+ @Override
+ public void addMessageHandler(MessageHandler listener) {
+ Class<?> target = Util.getMessageType(listener);
+ doAddMessageHandler(target, listener);
+ }
+
+
+ @Override
+ public <T> void addMessageHandler(Class<T> clazz, Partial<T> handler)
+ throws IllegalStateException {
+ doAddMessageHandler(clazz, handler);
+ }
+
+
+ @Override
+ public <T> void addMessageHandler(Class<T> clazz, Whole<T> handler)
+ throws IllegalStateException {
+ doAddMessageHandler(clazz, handler);
+ }
+
+
+ @SuppressWarnings("unchecked")
+ private void doAddMessageHandler(Class<?> target, MessageHandler listener) {
+ checkState();
+
+ // Message handlers that require decoders may map to text messages,
+ // binary messages, both or neither.
+
+ // The frame processing code expects binary message handlers to
+ // accept ByteBuffer
+
+ // Use the POJO message handler wrappers as they are designed to wrap
+ // arbitrary objects with MessageHandlers and can wrap MessageHandlers
+ // just as easily.
+
+ Set<MessageHandlerResult> mhResults = Util.getMessageHandlers(target, listener,
+ endpointConfig, this);
+
+ for (MessageHandlerResult mhResult : mhResults) {
+ switch (mhResult.getType()) {
+ case TEXT: {
+ if (textMessageHandler != null) {
+ throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText"));
+ }
+ textMessageHandler = mhResult.getHandler();
+ break;
+ }
+ case BINARY: {
+ if (binaryMessageHandler != null) {
+ throw new IllegalStateException(
+ sm.getString("wsSession.duplicateHandlerBinary"));
+ }
+ binaryMessageHandler = mhResult.getHandler();
+ break;
+ }
+ case PONG: {
+ if (pongMessageHandler != null) {
+ throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong"));
+ }
+ MessageHandler handler = mhResult.getHandler();
+ if (handler instanceof MessageHandler.Whole<?>) {
+ pongMessageHandler = (MessageHandler.Whole<PongMessage>) handler;
+ } else {
+ throw new IllegalStateException(
+ sm.getString("wsSession.invalidHandlerTypePong"));
+ }
+
+ break;
+ }
+ default: {
+ throw new IllegalArgumentException(
+ sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType()));
+ }
+ }
+ }
+ }
+
+
+ @Override
+ public Set<MessageHandler> getMessageHandlers() {
+ checkState();
+ Set<MessageHandler> result = new HashSet<>();
+ if (binaryMessageHandler != null) {
+ result.add(binaryMessageHandler);
+ }
+ if (textMessageHandler != null) {
+ result.add(textMessageHandler);
+ }
+ if (pongMessageHandler != null) {
+ result.add(pongMessageHandler);
+ }
+ return result;
+ }
+
+
+ @Override
+ public void removeMessageHandler(MessageHandler listener) {
+ checkState();
+ if (listener == null) {
+ return;
+ }
+
+ MessageHandler wrapped = null;
+
+ if (listener instanceof WrappedMessageHandler) {
+ wrapped = ((WrappedMessageHandler) listener).getWrappedHandler();
+ }
+
+ if (wrapped == null) {
+ wrapped = listener;
+ }
+
+ boolean removed = false;
+ if (wrapped.equals(textMessageHandler) || listener.equals(textMessageHandler)) {
+ textMessageHandler = null;
+ removed = true;
+ }
+
+ if (wrapped.equals(binaryMessageHandler) || listener.equals(binaryMessageHandler)) {
+ binaryMessageHandler = null;
+ removed = true;
+ }
+
+ if (wrapped.equals(pongMessageHandler) || listener.equals(pongMessageHandler)) {
+ pongMessageHandler = null;
+ removed = true;
+ }
+
+ if (!removed) {
+ // ISE for now. Could swallow this silently / log this if the ISE
+ // becomes a problem
+ throw new IllegalStateException(
+ sm.getString("wsSession.removeHandlerFailed", listener));
+ }
+ }
+
+
+ @Override
+ public String getProtocolVersion() {
+ checkState();
+ return Constants.WS_VERSION_HEADER_VALUE;
+ }
+
+
+ @Override
+ public String getNegotiatedSubprotocol() {
+ checkState();
+ return subProtocol;
+ }
+
+
+ @Override
+ public List<Extension> getNegotiatedExtensions() {
+ checkState();
+ return negotiatedExtensions;
+ }
+
+
+ @Override
+ public boolean isSecure() {
+ checkState();
+ return secure;
+ }
+
+
+ @Override
+ public boolean isOpen() {
+ return state == State.OPEN;
+ }
+
+
+ @Override
+ public long getMaxIdleTimeout() {
+ checkState();
+ return maxIdleTimeout;
+ }
+
+
+ @Override
+ public void setMaxIdleTimeout(long timeout) {
+ checkState();
+ this.maxIdleTimeout = timeout;
+ }
+
+
+ @Override
+ public void setMaxBinaryMessageBufferSize(int max) {
+ checkState();
+ this.maxBinaryMessageBufferSize = max;
+ }
+
+
+ @Override
+ public int getMaxBinaryMessageBufferSize() {
+ checkState();
+ return maxBinaryMessageBufferSize;
+ }
+
+
+ @Override
+ public void setMaxTextMessageBufferSize(int max) {
+ checkState();
+ this.maxTextMessageBufferSize = max;
+ }
+
+
+ @Override
+ public int getMaxTextMessageBufferSize() {
+ checkState();
+ return maxTextMessageBufferSize;
+ }
+
+
+ @Override
+ public Set<Session> getOpenSessions() {
+ checkState();
+ return webSocketContainer.getOpenSessions(localEndpoint);
+ }
+
+
+ @Override
+ public RemoteEndpoint.Async getAsyncRemote() {
+ checkState();
+ return remoteEndpointAsync;
+ }
+
+
+ @Override
+ public RemoteEndpoint.Basic getBasicRemote() {
+ checkState();
+ return remoteEndpointBasic;
+ }
+
+
+ @Override
+ public void close() throws IOException {
+ close(new CloseReason(CloseCodes.NORMAL_CLOSURE, ""));
+ }
+
+
+ @Override
+ public void close(CloseReason closeReason) throws IOException {
+ doClose(closeReason, closeReason);
+ }
+
+
+ /**
+ * WebSocket 1.0. Section 2.1.5.
+ * Need internal close method as spec requires that the local endpoint
+ * receives a 1006 on timeout.
+ *
+ * @param closeReasonMessage The close reason to pass to the remote endpoint
+ * @param closeReasonLocal The close reason to pass to the local endpoint
+ */
+ public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal) {
+ // Double-checked locking. OK because state is volatile
+ if (state != State.OPEN) {
+ return;
+ }
+
+ synchronized (stateLock) {
+ if (state != State.OPEN) {
+ return;
+ }
+
+ if (log.isDebugEnabled()) {
+ log.debug(sm.getString("wsSession.doClose", id));
+ }
+ try {
+ wsRemoteEndpoint.setBatchingAllowed(false);
+ } catch (IOException e) {
+ log.warn(sm.getString("wsSession.flushFailOnClose"), e);
+ fireEndpointOnError(e);
+ }
+
+ state = State.OUTPUT_CLOSED;
+
+ sendCloseMessage(closeReasonMessage);
+ fireEndpointOnClose(closeReasonLocal);
+ }
+
+ IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
+ SendResult sr = new SendResult(ioe);
+ for (FutureToSendHandler f2sh : futures.keySet()) {
+ f2sh.onResult(sr);
+ }
+ }
+
+
+ /**
+ * Called when a close message is received. Should only ever happen once.
+ * Also called after a protocol error when the ProtocolHandler needs to
+ * force the closing of the connection.
+ *
+ * @param closeReason The reason contained within the received close
+ * message.
+ */
+ public void onClose(CloseReason closeReason) {
+
+ synchronized (stateLock) {
+ if (state != State.CLOSED) {
+ try {
+ wsRemoteEndpoint.setBatchingAllowed(false);
+ } catch (IOException e) {
+ log.warn(sm.getString("wsSession.flushFailOnClose"), e);
+ fireEndpointOnError(e);
+ }
+ if (state == State.OPEN) {
+ state = State.OUTPUT_CLOSED;
+ sendCloseMessage(closeReason);
+ fireEndpointOnClose(closeReason);
+ }
+ state = State.CLOSED;
+
+ // Close the socket
+ wsRemoteEndpoint.close();
+ }
+ }
+ }
+
+
+ public void onClose() {
+
+ synchronized (stateLock) {
+ if (state != State.CLOSED) {
+ try {
+ wsRemoteEndpoint.setBatchingAllowed(false);
+ } catch (IOException e) {
+ log.warn(sm.getString("wsSession.flushFailOnClose"), e);
+ fireEndpointOnError(e);
+ }
+ if (state == State.OPEN) {
+ state = State.OUTPUT_CLOSED;
+ fireEndpointOnClose(new CloseReason(
+ CloseReason.CloseCodes.NORMAL_CLOSURE, ""));
+ }
+ state = State.CLOSED;
+
+ // Close the socket
+ wsRemoteEndpoint.close();
+ }
+ }
+ }
+
+
+ private void fireEndpointOnClose(CloseReason closeReason) {
+
+ // Fire the onClose event
+ Throwable throwable = null;
+ InstanceManager instanceManager = webSocketContainer.getInstanceManager();
+ if (instanceManager == null) {
+ instanceManager = InstanceManagerBindings.get(applicationClassLoader);
+ }
+ Thread t = Thread.currentThread();
+ ClassLoader cl = t.getContextClassLoader();
+ t.setContextClassLoader(applicationClassLoader);
+ try {
+ localEndpoint.onClose(this, closeReason);
+ } catch (Throwable t1) {
+ ExceptionUtils.handleThrowable(t1);
+ throwable = t1;
+ } finally {
+ if (instanceManager != null) {
+ try {
+ instanceManager.destroyInstance(localEndpoint);
+ } catch (Throwable t2) {
+ ExceptionUtils.handleThrowable(t2);
+ if (throwable == null) {
+ throwable = t2;
+ }
+ }
+ }
+ t.setContextClassLoader(cl);
+ }
+
+ if (throwable != null) {
+ fireEndpointOnError(throwable);
+ }
+ }
+
+
+ private void fireEndpointOnError(Throwable throwable) {
+
+ // Fire the onError event
+ Thread t = Thread.currentThread();
+ ClassLoader cl = t.getContextClassLoader();
+ t.setContextClassLoader(applicationClassLoader);
+ try {
+ localEndpoint.onError(this, throwable);
+ } finally {
+ t.setContextClassLoader(cl);
+ }
+ }
+
+
+ private void sendCloseMessage(CloseReason closeReason) {
+ // 125 is maximum size for the payload of a control message
+ ByteBuffer msg = ByteBuffer.allocate(125);
+ CloseCode closeCode = closeReason.getCloseCode();
+ // CLOSED_ABNORMALLY should not be put on the wire
+ if (closeCode == CloseCodes.CLOSED_ABNORMALLY) {
+ // PROTOCOL_ERROR is probably better than GOING_AWAY here
+ msg.putShort((short) CloseCodes.PROTOCOL_ERROR.getCode());
+ } else {
+ msg.putShort((short) closeCode.getCode());
+ }
+
+ String reason = closeReason.getReasonPhrase();
+ if (reason != null && reason.length() > 0) {
+ appendCloseReasonWithTruncation(msg, reason);
+ }
+ msg.flip();
+ try {
+ wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true);
+ } catch (IOException | WritePendingException e) {
+ // Failed to send close message. Close the socket and let the caller
+ // deal with the Exception
+ if (log.isDebugEnabled()) {
+ log.debug(sm.getString("wsSession.sendCloseFail", id), e);
+ }
+ wsRemoteEndpoint.close();
+ // Failure to send a close message is not unexpected in the case of
+ // an abnormal closure (usually triggered by a failure to read/write
+ // from/to the client. In this case do not trigger the endpoint's
+ // error handling
+ if (closeCode != CloseCodes.CLOSED_ABNORMALLY) {
+ localEndpoint.onError(this, e);
+ }
+ } finally {
+ webSocketContainer.unregisterSession(localEndpoint, this);
+ }
+ }
+
+
+ /**
+ * Use protected so unit tests can access this method directly.
+ * @param msg The message
+ * @param reason The reason
+ */
+ protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) {
+ // Once the close code has been added there are a maximum of 123 bytes
+ // left for the reason phrase. If it is truncated then care needs to be
+ // taken to ensure the bytes are not truncated in the middle of a
+ // multi-byte UTF-8 character.
+ byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8);
+
+ if (reasonBytes.length <= 123) {
+ // No need to truncate
+ msg.put(reasonBytes);
+ } else {
+ // Need to truncate
+ int remaining = 123 - ELLIPSIS_BYTES_LEN;
+ int pos = 0;
+ byte[] bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8);
+ while (remaining >= bytesNext.length) {
+ msg.put(bytesNext);
+ remaining -= bytesNext.length;
+ pos++;
+ bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8);
+ }
+ msg.put(ELLIPSIS_BYTES);
+ }
+ }
+
+
+ /**
+ * Make the session aware of a {@link FutureToSendHandler} that will need to
+ * be forcibly closed if the session closes before the
+ * {@link FutureToSendHandler} completes.
+ * @param f2sh The handler
+ */
+ protected void registerFuture(FutureToSendHandler f2sh) {
+ // Ideally, this code should sync on stateLock so that the correct
+ // action is taken based on the current state of the connection.
+ // However, a sync on stateLock can't be used here as it will create the
+ // possibility of a dead-lock. See BZ 61183.
+ // Therefore, a slightly less efficient approach is used.
+
+ // Always register the future.
+ futures.put(f2sh, f2sh);
+
+ if (state == State.OPEN) {
+ // The session is open. The future has been registered with the open
+ // session. Normal processing continues.
+ return;
+ }
+
+ // The session is closed. The future may or may not have been registered
+ // in time for it to be processed during session closure.
+
+ if (f2sh.isDone()) {
+ // The future has completed. It is not known if the future was
+ // completed normally by the I/O layer or in error by doClose(). It
+ // doesn't matter which. There is nothing more to do here.
+ return;
+ }
+
+ // The session is closed. The Future had not completed when last checked.
+ // There is a small timing window that means the Future may have been
+ // completed since the last check. There is also the possibility that
+ // the Future was not registered in time to be cleaned up during session
+ // close.
+ // Attempt to complete the Future with an error result as this ensures
+ // that the Future completes and any client code waiting on it does not
+ // hang. It is slightly inefficient since the Future may have been
+ // completed in another thread or another thread may be about to
+ // complete the Future but knowing if this is the case requires the sync
+ // on stateLock (see above).
+ // Note: If multiple attempts are made to complete the Future, the
+ // second and subsequent attempts are ignored.
+
+ IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
+ SendResult sr = new SendResult(ioe);
+ f2sh.onResult(sr);
+ }
+
+
+ /**
+ * Remove a {@link FutureToSendHandler} from the set of tracked instances.
+ * @param f2sh The handler
+ */
+ protected void unregisterFuture(FutureToSendHandler f2sh) {
+ futures.remove(f2sh);
+ }
+
+
+ @Override
+ public URI getRequestURI() {
+ checkState();
+ return requestUri;
+ }
+
+
+ @Override
+ public Map<String, List<String>> getRequestParameterMap() {
+ checkState();
+ return requestParameterMap;
+ }
+
+
+ @Override
+ public String getQueryString() {
+ checkState();
+ return queryString;
+ }
+
+
+ @Override
+ public Principal getUserPrincipal() {
+ checkState();
+ return userPrincipal;
+ }
+
+
+ @Override
+ public Map<String, String> getPathParameters() {
+ checkState();
+ return pathParameters;
+ }
+
+
+ @Override
+ public String getId() {
+ return id;
+ }
+
+
+ @Override
+ public Map<String, Object> getUserProperties() {
+ checkState();
+ return userProperties;
+ }
+
+
+ public Endpoint getLocal() {
+ return localEndpoint;
+ }
+
+
+ public String getHttpSessionId() {
+ return httpSessionId;
+ }
+
+ private ByteBuffer rawFragments;
+
+ public void processFrame(ByteBuffer buf, byte opCode, boolean last)
+ throws IOException
+ {
+ if (state == State.CLOSED) {
+ return;
+ }
+
+ if (opCode == Constants.OPCODE_CONTINUATION) {
+ opCode = startOpCode;
+
+ if (rawFragments != null && rawFragments.position() > 0) {
+ rawFragments.put(buf);
+ rawFragments.flip();
+ buf = rawFragments;
+ }
+ } else {
+ if (!last && (opCode == Constants.OPCODE_BINARY ||
+ opCode == Constants.OPCODE_TEXT)) {
+ startOpCode = opCode;
+
+ if (rawFragments != null) {
+ rawFragments.clear();
+ }
+ }
+ }
+
+ if (last) {
+ startOpCode = Constants.OPCODE_CONTINUATION;
+ }
+
+ if (opCode == Constants.OPCODE_PONG) {
+ if (pongMessageHandler != null) {
+ final ByteBuffer b = buf;
+
+ PongMessage pongMessage = new PongMessage() {
+ @Override
+ public ByteBuffer getApplicationData() {
+ return b;
+ }
+ };
+
+ pongMessageHandler.onMessage(pongMessage);
+ }
+ }
+
+ if (opCode == Constants.OPCODE_CLOSE) {
+ CloseReason closeReason;
+
+ if (buf.remaining() >= 2) {
+ short closeCode = buf.order(ByteOrder.BIG_ENDIAN).getShort();
+
+ closeReason = new CloseReason(
+ CloseReason.CloseCodes.getCloseCode(closeCode),
+ buf.asCharBuffer().toString());
+ } else {
+ closeReason = new CloseReason(
+ CloseReason.CloseCodes.NORMAL_CLOSURE, "");
+ }
+
+ onClose(closeReason);
+ }
+
+ if (opCode == Constants.OPCODE_BINARY) {
+ onMessage(buf, last);
+ }
+
+ if (opCode == Constants.OPCODE_TEXT) {
+ if (messageBufferText.position() == 0 && maxTextMessageBufferSize != messageBufferText.capacity()) {
+ messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize);
+ }
+
+ CoderResult cr = utf8DecoderMessage.decode(buf, messageBufferText, last);
+ if (cr.isError()) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.NOT_CONSISTENT,
+ sm.getString("wsFrame.invalidUtf8")));
+ } else if (cr.isOverflow()) {
+ // Ran out of space in text buffer - flush it
+ if (hasTextPartial()) {
+ do {
+ onMessage(messageBufferText, false);
+
+ cr = utf8DecoderMessage.decode(buf, messageBufferText, last);
+ } while (cr.isOverflow());
+ } else {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.textMessageTooBig")));
+ }
+ } else if (cr.isUnderflow() && !last) {
+ updateRawFragments(buf, last);
+
+ if (hasTextPartial()) {
+ onMessage(messageBufferText, false);
+ }
+
+ return;
+ }
+
+ if (last) {
+ utf8DecoderMessage.reset();
+ }
+
+ updateRawFragments(buf, last);
+
+ onMessage(messageBufferText, last);
+ }
+ }
+
+
+ private boolean hasTextPartial() {
+ return textMessageHandler instanceof MessageHandler.Partial<?>;
+ }
+
+
+ private void onMessage(CharBuffer buf, boolean last) throws IOException {
+ buf.flip();
+ try {
+ onMessage(buf.toString(), last);
+ } catch (Throwable t) {
+ handleThrowableOnSend(t);
+ } finally {
+ buf.clear();
+ }
+ }
+
+
+ private void updateRawFragments(ByteBuffer buf, boolean last) {
+ if (!last && buf.remaining() > 0) {
+ if (buf == rawFragments) {
+ buf.compact();
+ } else {
+ if (rawFragments == null || (rawFragments.position() == 0 && maxTextMessageBufferSize != rawFragments.capacity())) {
+ rawFragments = ByteBuffer.allocateDirect(maxTextMessageBufferSize);
+ }
+ rawFragments.put(buf);
+ }
+ } else {
+ if (rawFragments != null) {
+ rawFragments.clear();
+ }
+ }
+ }
+
+
+ @SuppressWarnings("unchecked")
+ public void onMessage(String text, boolean last) {
+ if (hasTextPartial()) {
+ ((MessageHandler.Partial<String>) textMessageHandler).onMessage(text, last);
+ } else {
+ // Caller ensures last == true if this branch is used
+ ((MessageHandler.Whole<String>) textMessageHandler).onMessage(text);
+ }
+ }
+
+
+ @SuppressWarnings("unchecked")
+ public void onMessage(ByteBuffer buf, boolean last)
+ throws IOException
+ {
+ if (binaryMessageHandler instanceof MessageHandler.Partial<?>) {
+ ((MessageHandler.Partial<ByteBuffer>) binaryMessageHandler).onMessage(buf, last);
+ } else {
+ if (last && (binaryBuffer == null || binaryBuffer.position() == 0)) {
+ ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(buf);
+ return;
+ }
+
+ if (binaryBuffer == null ||
+ (binaryBuffer.position() == 0 && binaryBuffer.capacity() != maxBinaryMessageBufferSize))
+ {
+ binaryBuffer = ByteBuffer.allocateDirect(maxBinaryMessageBufferSize);
+ }
+
+ if (binaryBuffer.remaining() < buf.remaining()) {
+ throw new WsIOException(new CloseReason(
+ CloseCodes.TOO_BIG,
+ sm.getString("wsFrame.textMessageTooBig")));
+ }
+
+ binaryBuffer.put(buf);
+
+ if (last) {
+ binaryBuffer.flip();
+ try {
+ ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(binaryBuffer);
+ } finally {
+ binaryBuffer.clear();
+ }
+ }
+ }
+ }
+
+
+ private void handleThrowableOnSend(Throwable t) throws WsIOException {
+ ExceptionUtils.handleThrowable(t);
+ getLocal().onError(this, t);
+ CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY,
+ sm.getString("wsFrame.ioeTriggeredClose"));
+ throw new WsIOException(cr);
+ }
+
+
+ protected MessageHandler getTextMessageHandler() {
+ return textMessageHandler;
+ }
+
+
+ protected MessageHandler getBinaryMessageHandler() {
+ return binaryMessageHandler;
+ }
+
+
+ protected MessageHandler.Whole<PongMessage> getPongMessageHandler() {
+ return pongMessageHandler;
+ }
+
+
+ protected void updateLastActive() {
+ lastActive = System.currentTimeMillis();
+ }
+
+
+ protected void checkExpiration() {
+ long timeout = maxIdleTimeout;
+ if (timeout < 1) {
+ return;
+ }
+
+ if (System.currentTimeMillis() - lastActive > timeout) {
+ String msg = sm.getString("wsSession.timeout", getId());
+ if (log.isDebugEnabled()) {
+ log.debug(msg);
+ }
+ doClose(new CloseReason(CloseCodes.GOING_AWAY, msg),
+ new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg));
+ }
+ }
+
+
+ private void checkState() {
+ if (state == State.CLOSED) {
+ /*
+ * As per RFC 6455, a WebSocket connection is considered to be
+ * closed once a peer has sent and received a WebSocket close frame.
+ */
+ throw new IllegalStateException(sm.getString("wsSession.closed", id));
+ }
+ }
+
+ private enum State {
+ OPEN,
+ OUTPUT_CLOSED,
+ CLOSED
+ }
+}
diff --git a/src/java/nginx/unit/websocket/WsWebSocketContainer.java b/src/java/nginx/unit/websocket/WsWebSocketContainer.java
new file mode 100644
index 00000000..282665ef
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsWebSocketContainer.java
@@ -0,0 +1,1123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.EOFException;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.InetSocketAddress;
+import java.net.Proxy;
+import java.net.ProxySelector;
+import java.net.SocketAddress;
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.nio.channels.AsynchronousChannelGroup;
+import java.nio.channels.AsynchronousSocketChannel;
+import java.nio.charset.StandardCharsets;
+import java.security.KeyStore;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import javax.net.ssl.SSLContext;
+import javax.net.ssl.SSLEngine;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLParameters;
+import javax.net.ssl.TrustManagerFactory;
+import javax.websocket.ClientEndpoint;
+import javax.websocket.ClientEndpointConfig;
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.Endpoint;
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.Session;
+import javax.websocket.WebSocketContainer;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.InstanceManager;
+import org.apache.tomcat.util.buf.StringUtils;
+import org.apache.tomcat.util.codec.binary.Base64;
+import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.pojo.PojoEndpointClient;
+
+public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess {
+
+ private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class);
+ private static final Random RANDOM = new Random();
+ private static final byte[] CRLF = new byte[] { 13, 10 };
+
+ private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1);
+ private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1);
+ private static final byte[] HTTP_VERSION_BYTES =
+ " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1);
+
+ private volatile AsynchronousChannelGroup asynchronousChannelGroup = null;
+ private final Object asynchronousChannelGroupLock = new Object();
+
+ private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static
+ private final Map<Endpoint, Set<WsSession>> endpointSessionMap =
+ new HashMap<>();
+ private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>();
+ private final Object endPointSessionMapLock = new Object();
+
+ private long defaultAsyncTimeout = -1;
+ private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
+ private volatile long defaultMaxSessionIdleTimeout = 0;
+ private int backgroundProcessCount = 0;
+ private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD;
+
+ private InstanceManager instanceManager;
+
+ InstanceManager getInstanceManager() {
+ return instanceManager;
+ }
+
+ protected void setInstanceManager(InstanceManager instanceManager) {
+ this.instanceManager = instanceManager;
+ }
+
+ @Override
+ public Session connectToServer(Object pojo, URI path)
+ throws DeploymentException {
+
+ ClientEndpoint annotation =
+ pojo.getClass().getAnnotation(ClientEndpoint.class);
+ if (annotation == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.missingAnnotation",
+ pojo.getClass().getName()));
+ }
+
+ Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders()));
+
+ Class<? extends ClientEndpointConfig.Configurator> configuratorClazz =
+ annotation.configurator();
+
+ ClientEndpointConfig.Configurator configurator = null;
+ if (!ClientEndpointConfig.Configurator.class.equals(
+ configuratorClazz)) {
+ try {
+ configurator = configuratorClazz.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.defaultConfiguratorFail"), e);
+ }
+ }
+
+ ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create();
+ // Avoid NPE when using RI API JAR - see BZ 56343
+ if (configurator != null) {
+ builder.configurator(configurator);
+ }
+ ClientEndpointConfig config = builder.
+ decoders(Arrays.asList(annotation.decoders())).
+ encoders(Arrays.asList(annotation.encoders())).
+ preferredSubprotocols(Arrays.asList(annotation.subprotocols())).
+ build();
+ return connectToServer(ep, config, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Class<?> annotatedEndpointClass, URI path)
+ throws DeploymentException {
+
+ Object pojo;
+ try {
+ pojo = annotatedEndpointClass.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.endpointCreateFail",
+ annotatedEndpointClass.getName()), e);
+ }
+
+ return connectToServer(pojo, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Class<? extends Endpoint> clazz,
+ ClientEndpointConfig clientEndpointConfiguration, URI path)
+ throws DeploymentException {
+
+ Endpoint endpoint;
+ try {
+ endpoint = clazz.getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.endpointCreateFail", clazz.getName()),
+ e);
+ }
+
+ return connectToServer(endpoint, clientEndpointConfiguration, path);
+ }
+
+
+ @Override
+ public Session connectToServer(Endpoint endpoint,
+ ClientEndpointConfig clientEndpointConfiguration, URI path)
+ throws DeploymentException {
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>());
+ }
+
+ private Session connectToServerRecursive(Endpoint endpoint,
+ ClientEndpointConfig clientEndpointConfiguration, URI path,
+ Set<URI> redirectSet)
+ throws DeploymentException {
+
+ boolean secure = false;
+ ByteBuffer proxyConnect = null;
+ URI proxyPath;
+
+ // Validate scheme (and build proxyPath)
+ String scheme = path.getScheme();
+ if ("ws".equalsIgnoreCase(scheme)) {
+ proxyPath = URI.create("http" + path.toString().substring(2));
+ } else if ("wss".equalsIgnoreCase(scheme)) {
+ proxyPath = URI.create("https" + path.toString().substring(3));
+ secure = true;
+ } else {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.pathWrongScheme", scheme));
+ }
+
+ // Validate host
+ String host = path.getHost();
+ if (host == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.pathNoHost"));
+ }
+ int port = path.getPort();
+
+ SocketAddress sa = null;
+
+ // Check to see if a proxy is configured. Javadoc indicates return value
+ // will never be null
+ List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath);
+ Proxy selectedProxy = null;
+ for (Proxy proxy : proxies) {
+ if (proxy.type().equals(Proxy.Type.HTTP)) {
+ sa = proxy.address();
+ if (sa instanceof InetSocketAddress) {
+ InetSocketAddress inet = (InetSocketAddress) sa;
+ if (inet.isUnresolved()) {
+ sa = new InetSocketAddress(inet.getHostName(), inet.getPort());
+ }
+ }
+ selectedProxy = proxy;
+ break;
+ }
+ }
+
+ // If the port is not explicitly specified, compute it based on the
+ // scheme
+ if (port == -1) {
+ if ("ws".equalsIgnoreCase(scheme)) {
+ port = 80;
+ } else {
+ // Must be wss due to scheme validation above
+ port = 443;
+ }
+ }
+
+ // If sa is null, no proxy is configured so need to create sa
+ if (sa == null) {
+ sa = new InetSocketAddress(host, port);
+ } else {
+ proxyConnect = createProxyRequest(host, port);
+ }
+
+ // Create the initial HTTP request to open the WebSocket connection
+ Map<String, List<String>> reqHeaders = createRequestHeaders(host, port,
+ clientEndpointConfiguration);
+ clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders);
+ if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null
+ && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) {
+ List<String> originValues = new ArrayList<>(1);
+ originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE);
+ reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues);
+ }
+ ByteBuffer request = createRequest(path, reqHeaders);
+
+ AsynchronousSocketChannel socketChannel;
+ try {
+ socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup());
+ } catch (IOException ioe) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe);
+ }
+
+ Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties();
+
+ // Get the connection timeout
+ long timeout = Constants.IO_TIMEOUT_MS_DEFAULT;
+ String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY);
+ if (timeoutValue != null) {
+ timeout = Long.valueOf(timeoutValue).intValue();
+ }
+
+ // Set-up
+ // Same size as the WsFrame input buffer
+ ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize());
+ String subProtocol;
+ boolean success = false;
+ List<Extension> extensionsAgreed = new ArrayList<>();
+ Transformation transformation = null;
+
+ // Open the connection
+ Future<Void> fConnect = socketChannel.connect(sa);
+ AsyncChannelWrapper channel = null;
+
+ if (proxyConnect != null) {
+ try {
+ fConnect.get(timeout, TimeUnit.MILLISECONDS);
+ // Proxy CONNECT is clear text
+ channel = new AsyncChannelWrapperNonSecure(socketChannel);
+ writeRequest(channel, proxyConnect, timeout);
+ HttpResponse httpResponse = processResponse(response, channel, timeout);
+ if (httpResponse.getStatus() != 200) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.proxyConnectFail", selectedProxy,
+ Integer.toString(httpResponse.getStatus())));
+ }
+ } catch (TimeoutException | InterruptedException | ExecutionException |
+ EOFException e) {
+ if (channel != null) {
+ channel.close();
+ }
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
+ }
+ }
+
+ if (secure) {
+ // Regardless of whether a non-secure wrapper was created for a
+ // proxy CONNECT, need to use TLS from this point on so wrap the
+ // original AsynchronousSocketChannel
+ SSLEngine sslEngine = createSSLEngine(userProperties, host, port);
+ channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine);
+ } else if (channel == null) {
+ // Only need to wrap as this point if it wasn't wrapped to process a
+ // proxy CONNECT
+ channel = new AsyncChannelWrapperNonSecure(socketChannel);
+ }
+
+ try {
+ fConnect.get(timeout, TimeUnit.MILLISECONDS);
+
+ Future<Void> fHandshake = channel.handshake();
+ fHandshake.get(timeout, TimeUnit.MILLISECONDS);
+
+ writeRequest(channel, request, timeout);
+
+ HttpResponse httpResponse = processResponse(response, channel, timeout);
+
+ // Check maximum permitted redirects
+ int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT;
+ String maxRedirectsValue =
+ (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY);
+ if (maxRedirectsValue != null) {
+ maxRedirects = Integer.parseInt(maxRedirectsValue);
+ }
+
+ if (httpResponse.status != 101) {
+ if(isRedirectStatus(httpResponse.status)){
+ List<String> locationHeader =
+ httpResponse.getHandshakeResponse().getHeaders().get(
+ Constants.LOCATION_HEADER_NAME);
+
+ if (locationHeader == null || locationHeader.isEmpty() ||
+ locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.missingLocationHeader",
+ Integer.toString(httpResponse.status)));
+ }
+
+ URI redirectLocation = URI.create(locationHeader.get(0)).normalize();
+
+ if (!redirectLocation.isAbsolute()) {
+ redirectLocation = path.resolve(redirectLocation);
+ }
+
+ String redirectScheme = redirectLocation.getScheme().toLowerCase();
+
+ if (redirectScheme.startsWith("http")) {
+ redirectLocation = new URI(redirectScheme.replace("http", "ws"),
+ redirectLocation.getUserInfo(), redirectLocation.getHost(),
+ redirectLocation.getPort(), redirectLocation.getPath(),
+ redirectLocation.getQuery(), redirectLocation.getFragment());
+ }
+
+ if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.redirectThreshold", redirectLocation,
+ Integer.toString(redirectSet.size()),
+ Integer.toString(maxRedirects)));
+ }
+
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet);
+
+ }
+
+ else if (httpResponse.status == 401) {
+
+ if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.failedAuthentication",
+ Integer.valueOf(httpResponse.status)));
+ }
+
+ List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse()
+ .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME);
+
+ if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() ||
+ wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.missingWWWAuthenticateHeader",
+ Integer.toString(httpResponse.status)));
+ }
+
+ String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0];
+ String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1)
+ .split("\\s", 3)[1];
+
+ Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme);
+
+ if (auth == null) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.unsupportedAuthScheme",
+ Integer.valueOf(httpResponse.status), authScheme));
+ }
+
+ userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization(
+ requestUri, wwwAuthenticateHeaders.get(0), userProperties));
+
+ return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet);
+
+ }
+
+ else {
+ throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus",
+ Integer.toString(httpResponse.status)));
+ }
+ }
+ HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse();
+ clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse);
+
+ // Sub-protocol
+ List<String> protocolHeaders = handshakeResponse.getHeaders().get(
+ Constants.WS_PROTOCOL_HEADER_NAME);
+ if (protocolHeaders == null || protocolHeaders.size() == 0) {
+ subProtocol = null;
+ } else if (protocolHeaders.size() == 1) {
+ subProtocol = protocolHeaders.get(0);
+ } else {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.invalidSubProtocol"));
+ }
+
+ // Extensions
+ // Should normally only be one header but handle the case of
+ // multiple headers
+ List<String> extHeaders = handshakeResponse.getHeaders().get(
+ Constants.WS_EXTENSIONS_HEADER_NAME);
+ if (extHeaders != null) {
+ for (String extHeader : extHeaders) {
+ Util.parseExtensionHeader(extensionsAgreed, extHeader);
+ }
+ }
+
+ // Build the transformations
+ TransformationFactory factory = TransformationFactory.getInstance();
+ for (Extension extension : extensionsAgreed) {
+ List<List<Extension.Parameter>> wrapper = new ArrayList<>(1);
+ wrapper.add(extension.getParameters());
+ Transformation t = factory.create(extension.getName(), wrapper, false);
+ if (t == null) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidExtensionParameters"));
+ }
+ if (transformation == null) {
+ transformation = t;
+ } else {
+ transformation.setNext(t);
+ }
+ }
+
+ success = true;
+ } catch (ExecutionException | InterruptedException | SSLException |
+ EOFException | TimeoutException | URISyntaxException | AuthenticationException e) {
+ throw new DeploymentException(
+ sm.getString("wsWebSocketContainer.httpRequestFailed"), e);
+ } finally {
+ if (!success) {
+ channel.close();
+ }
+ }
+
+ // Switch to WebSocket
+ WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel);
+
+ WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient,
+ this, null, null, null, null, null, extensionsAgreed,
+ subProtocol, Collections.<String,String>emptyMap(), secure,
+ clientEndpointConfiguration, null);
+
+ WsFrameClient wsFrameClient = new WsFrameClient(response, channel,
+ wsSession, transformation);
+ // WsFrame adds the necessary final transformations. Copy the
+ // completed transformation chain to the remote end point.
+ wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation());
+
+ endpoint.onOpen(wsSession, clientEndpointConfiguration);
+ registerSession(endpoint, wsSession);
+
+ /* It is possible that the server sent one or more messages as soon as
+ * the WebSocket connection was established. Depending on the exact
+ * timing of when those messages were sent they could be sat in the
+ * input buffer waiting to be read and will not trigger a "data
+ * available to read" event. Therefore, it is necessary to process the
+ * input buffer here. Note that this happens on the current thread which
+ * means that this thread will be used for any onMessage notifications.
+ * This is a special case. Subsequent "data available to read" events
+ * will be handled by threads from the AsyncChannelGroup's executor.
+ */
+ wsFrameClient.startInputProcessing();
+
+ return wsSession;
+ }
+
+
+ private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request,
+ long timeout) throws TimeoutException, InterruptedException, ExecutionException {
+ int toWrite = request.limit();
+
+ Future<Integer> fWrite = channel.write(request);
+ Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
+ toWrite -= thisWrite.intValue();
+
+ while (toWrite > 0) {
+ fWrite = channel.write(request);
+ thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS);
+ toWrite -= thisWrite.intValue();
+ }
+ }
+
+
+ private static boolean isRedirectStatus(int httpResponseCode) {
+
+ boolean isRedirect = false;
+
+ switch (httpResponseCode) {
+ case Constants.MULTIPLE_CHOICES:
+ case Constants.MOVED_PERMANENTLY:
+ case Constants.FOUND:
+ case Constants.SEE_OTHER:
+ case Constants.USE_PROXY:
+ case Constants.TEMPORARY_REDIRECT:
+ isRedirect = true;
+ break;
+ default:
+ break;
+ }
+
+ return isRedirect;
+ }
+
+
+ private static ByteBuffer createProxyRequest(String host, int port) {
+ StringBuilder request = new StringBuilder();
+ request.append("CONNECT ");
+ request.append(host);
+ request.append(':');
+ request.append(port);
+
+ request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: ");
+ request.append(host);
+ request.append(':');
+ request.append(port);
+
+ request.append("\r\n\r\n");
+
+ byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1);
+ return ByteBuffer.wrap(bytes);
+ }
+
+ protected void registerSession(Endpoint endpoint, WsSession wsSession) {
+
+ if (!wsSession.isOpen()) {
+ // The session was closed during onOpen. No need to register it.
+ return;
+ }
+ synchronized (endPointSessionMapLock) {
+ if (endpointSessionMap.size() == 0) {
+ BackgroundProcessManager.getInstance().register(this);
+ }
+ Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
+ if (wsSessions == null) {
+ wsSessions = new HashSet<>();
+ endpointSessionMap.put(endpoint, wsSessions);
+ }
+ wsSessions.add(wsSession);
+ }
+ sessions.put(wsSession, wsSession);
+ }
+
+
+ protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
+
+ synchronized (endPointSessionMapLock) {
+ Set<WsSession> wsSessions = endpointSessionMap.get(endpoint);
+ if (wsSessions != null) {
+ wsSessions.remove(wsSession);
+ if (wsSessions.size() == 0) {
+ endpointSessionMap.remove(endpoint);
+ }
+ }
+ if (endpointSessionMap.size() == 0) {
+ BackgroundProcessManager.getInstance().unregister(this);
+ }
+ }
+ sessions.remove(wsSession);
+ }
+
+
+ Set<Session> getOpenSessions(Endpoint endpoint) {
+ HashSet<Session> result = new HashSet<>();
+ synchronized (endPointSessionMapLock) {
+ Set<WsSession> sessions = endpointSessionMap.get(endpoint);
+ if (sessions != null) {
+ result.addAll(sessions);
+ }
+ }
+ return result;
+ }
+
+ private static Map<String, List<String>> createRequestHeaders(String host, int port,
+ ClientEndpointConfig clientEndpointConfiguration) {
+
+ Map<String, List<String>> headers = new HashMap<>();
+ List<Extension> extensions = clientEndpointConfiguration.getExtensions();
+ List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols();
+ Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties();
+
+ if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) {
+ List<String> authValues = new ArrayList<>(1);
+ authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME));
+ headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues);
+ }
+
+ // Host header
+ List<String> hostValues = new ArrayList<>(1);
+ if (port == -1) {
+ hostValues.add(host);
+ } else {
+ hostValues.add(host + ':' + port);
+ }
+
+ headers.put(Constants.HOST_HEADER_NAME, hostValues);
+
+ // Upgrade header
+ List<String> upgradeValues = new ArrayList<>(1);
+ upgradeValues.add(Constants.UPGRADE_HEADER_VALUE);
+ headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues);
+
+ // Connection header
+ List<String> connectionValues = new ArrayList<>(1);
+ connectionValues.add(Constants.CONNECTION_HEADER_VALUE);
+ headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues);
+
+ // WebSocket version header
+ List<String> wsVersionValues = new ArrayList<>(1);
+ wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE);
+ headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues);
+
+ // WebSocket key
+ List<String> wsKeyValues = new ArrayList<>(1);
+ wsKeyValues.add(generateWsKeyValue());
+ headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues);
+
+ // WebSocket sub-protocols
+ if (subProtocols != null && subProtocols.size() > 0) {
+ headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols);
+ }
+
+ // WebSocket extensions
+ if (extensions != null && extensions.size() > 0) {
+ headers.put(Constants.WS_EXTENSIONS_HEADER_NAME,
+ generateExtensionHeaders(extensions));
+ }
+
+ return headers;
+ }
+
+
+ private static List<String> generateExtensionHeaders(List<Extension> extensions) {
+ List<String> result = new ArrayList<>(extensions.size());
+ for (Extension extension : extensions) {
+ StringBuilder header = new StringBuilder();
+ header.append(extension.getName());
+ for (Extension.Parameter param : extension.getParameters()) {
+ header.append(';');
+ header.append(param.getName());
+ String value = param.getValue();
+ if (value != null && value.length() > 0) {
+ header.append('=');
+ header.append(value);
+ }
+ }
+ result.add(header.toString());
+ }
+ return result;
+ }
+
+
+ private static String generateWsKeyValue() {
+ byte[] keyBytes = new byte[16];
+ RANDOM.nextBytes(keyBytes);
+ return Base64.encodeBase64String(keyBytes);
+ }
+
+
+ private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) {
+ ByteBuffer result = ByteBuffer.allocate(4 * 1024);
+
+ // Request line
+ result.put(GET_BYTES);
+ if (null == uri.getPath() || "".equals(uri.getPath())) {
+ result.put(ROOT_URI_BYTES);
+ } else {
+ result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1));
+ }
+ String query = uri.getRawQuery();
+ if (query != null) {
+ result.put((byte) '?');
+ result.put(query.getBytes(StandardCharsets.ISO_8859_1));
+ }
+ result.put(HTTP_VERSION_BYTES);
+
+ // Headers
+ for (Entry<String, List<String>> entry : reqHeaders.entrySet()) {
+ result = addHeader(result, entry.getKey(), entry.getValue());
+ }
+
+ // Terminating CRLF
+ result.put(CRLF);
+
+ result.flip();
+
+ return result;
+ }
+
+
+ private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) {
+ if (values.isEmpty()) {
+ return result;
+ }
+
+ result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1));
+ result = putWithExpand(result, CRLF);
+
+ return result;
+ }
+
+
+ private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) {
+ if (bytes.length > input.remaining()) {
+ int newSize;
+ if (bytes.length > input.capacity()) {
+ newSize = 2 * bytes.length;
+ } else {
+ newSize = input.capacity() * 2;
+ }
+ ByteBuffer expanded = ByteBuffer.allocate(newSize);
+ input.flip();
+ expanded.put(input);
+ input = expanded;
+ }
+ return input.put(bytes);
+ }
+
+
+ /**
+ * Process response, blocking until HTTP response has been fully received.
+ * @throws ExecutionException
+ * @throws InterruptedException
+ * @throws DeploymentException
+ * @throws TimeoutException
+ */
+ private HttpResponse processResponse(ByteBuffer response,
+ AsyncChannelWrapper channel, long timeout) throws InterruptedException,
+ ExecutionException, DeploymentException, EOFException,
+ TimeoutException {
+
+ Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>();
+
+ int status = 0;
+ boolean readStatus = false;
+ boolean readHeaders = false;
+ String line = null;
+ while (!readHeaders) {
+ // On entering loop buffer will be empty and at the start of a new
+ // loop the buffer will have been fully read.
+ response.clear();
+ // Blocking read
+ Future<Integer> read = channel.read(response);
+ Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS);
+ if (bytesRead.intValue() == -1) {
+ throw new EOFException();
+ }
+ response.flip();
+ while (response.hasRemaining() && !readHeaders) {
+ if (line == null) {
+ line = readLine(response);
+ } else {
+ line += readLine(response);
+ }
+ if ("\r\n".equals(line)) {
+ readHeaders = true;
+ } else if (line.endsWith("\r\n")) {
+ if (readStatus) {
+ parseHeaders(line, headers);
+ } else {
+ status = parseStatus(line);
+ readStatus = true;
+ }
+ line = null;
+ }
+ }
+ }
+
+ return new HttpResponse(status, new WsHandshakeResponse(headers));
+ }
+
+
+ private int parseStatus(String line) throws DeploymentException {
+ // This client only understands HTTP 1.
+ // RFC2616 is case specific
+ String[] parts = line.trim().split(" ");
+ // CONNECT for proxy may return a 1.0 response
+ if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidStatus", line));
+ }
+ try {
+ return Integer.parseInt(parts[1]);
+ } catch (NumberFormatException nfe) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.invalidStatus", line));
+ }
+ }
+
+
+ private void parseHeaders(String line, Map<String,List<String>> headers) {
+ // Treat headers as single values by default.
+
+ int index = line.indexOf(':');
+ if (index == -1) {
+ log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line));
+ return;
+ }
+ // Header names are case insensitive so always use lower case
+ String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH);
+ // Multi-value headers are stored as a single header and the client is
+ // expected to handle splitting into individual values
+ String headerValue = line.substring(index + 1).trim();
+
+ List<String> values = headers.get(headerName);
+ if (values == null) {
+ values = new ArrayList<>(1);
+ headers.put(headerName, values);
+ }
+ values.add(headerValue);
+ }
+
+ private String readLine(ByteBuffer response) {
+ // All ISO-8859-1
+ StringBuilder sb = new StringBuilder();
+
+ char c = 0;
+ while (response.hasRemaining()) {
+ c = (char) response.get();
+ sb.append(c);
+ if (c == 10) {
+ break;
+ }
+ }
+
+ return sb.toString();
+ }
+
+
+ private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port)
+ throws DeploymentException {
+
+ try {
+ // See if a custom SSLContext has been provided
+ SSLContext sslContext =
+ (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY);
+
+ if (sslContext == null) {
+ // Create the SSL Context
+ sslContext = SSLContext.getInstance("TLS");
+
+ // Trust store
+ String sslTrustStoreValue =
+ (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY);
+ if (sslTrustStoreValue != null) {
+ String sslTrustStorePwdValue = (String) userProperties.get(
+ Constants.SSL_TRUSTSTORE_PWD_PROPERTY);
+ if (sslTrustStorePwdValue == null) {
+ sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT;
+ }
+
+ File keyStoreFile = new File(sslTrustStoreValue);
+ KeyStore ks = KeyStore.getInstance("JKS");
+ try (InputStream is = new FileInputStream(keyStoreFile)) {
+ ks.load(is, sslTrustStorePwdValue.toCharArray());
+ }
+
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(
+ TrustManagerFactory.getDefaultAlgorithm());
+ tmf.init(ks);
+
+ sslContext.init(null, tmf.getTrustManagers(), null);
+ } else {
+ sslContext.init(null, null, null);
+ }
+ }
+
+ SSLEngine engine = sslContext.createSSLEngine(host, port);
+
+ String sslProtocolsValue =
+ (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY);
+ if (sslProtocolsValue != null) {
+ engine.setEnabledProtocols(sslProtocolsValue.split(","));
+ }
+
+ engine.setUseClientMode(true);
+
+ // Enable host verification
+ // Start with current settings (returns a copy)
+ SSLParameters sslParams = engine.getSSLParameters();
+ // Use HTTPS since WebSocket starts over HTTP(S)
+ sslParams.setEndpointIdentificationAlgorithm("HTTPS");
+ // Write the parameters back
+ engine.setSSLParameters(sslParams);
+
+ return engine;
+ } catch (Exception e) {
+ throw new DeploymentException(sm.getString(
+ "wsWebSocketContainer.sslEngineFail"), e);
+ }
+ }
+
+
+ @Override
+ public long getDefaultMaxSessionIdleTimeout() {
+ return defaultMaxSessionIdleTimeout;
+ }
+
+
+ @Override
+ public void setDefaultMaxSessionIdleTimeout(long timeout) {
+ this.defaultMaxSessionIdleTimeout = timeout;
+ }
+
+
+ @Override
+ public int getDefaultMaxBinaryMessageBufferSize() {
+ return maxBinaryMessageBufferSize;
+ }
+
+
+ @Override
+ public void setDefaultMaxBinaryMessageBufferSize(int max) {
+ maxBinaryMessageBufferSize = max;
+ }
+
+
+ @Override
+ public int getDefaultMaxTextMessageBufferSize() {
+ return maxTextMessageBufferSize;
+ }
+
+
+ @Override
+ public void setDefaultMaxTextMessageBufferSize(int max) {
+ maxTextMessageBufferSize = max;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * Currently, this implementation does not support any extensions.
+ */
+ @Override
+ public Set<Extension> getInstalledExtensions() {
+ return Collections.emptySet();
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value for this implementation is -1.
+ */
+ @Override
+ public long getDefaultAsyncSendTimeout() {
+ return defaultAsyncTimeout;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value for this implementation is -1.
+ */
+ @Override
+ public void setAsyncSendTimeout(long timeout) {
+ this.defaultAsyncTimeout = timeout;
+ }
+
+
+ /**
+ * Cleans up the resources still in use by WebSocket sessions created from
+ * this container. This includes closing sessions and cancelling
+ * {@link Future}s associated with blocking read/writes.
+ */
+ public void destroy() {
+ CloseReason cr = new CloseReason(
+ CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown"));
+
+ for (WsSession session : sessions.keySet()) {
+ try {
+ session.close(cr);
+ } catch (IOException ioe) {
+ log.debug(sm.getString(
+ "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe);
+ }
+ }
+
+ // Only unregister with AsyncChannelGroupUtil if this instance
+ // registered with it
+ if (asynchronousChannelGroup != null) {
+ synchronized (asynchronousChannelGroupLock) {
+ if (asynchronousChannelGroup != null) {
+ AsyncChannelGroupUtil.unregister();
+ asynchronousChannelGroup = null;
+ }
+ }
+ }
+ }
+
+
+ private AsynchronousChannelGroup getAsynchronousChannelGroup() {
+ // Use AsyncChannelGroupUtil to share a common group amongst all
+ // WebSocket clients
+ AsynchronousChannelGroup result = asynchronousChannelGroup;
+ if (result == null) {
+ synchronized (asynchronousChannelGroupLock) {
+ if (asynchronousChannelGroup == null) {
+ asynchronousChannelGroup = AsyncChannelGroupUtil.register();
+ }
+ result = asynchronousChannelGroup;
+ }
+ }
+ return result;
+ }
+
+
+ // ----------------------------------------------- BackgroundProcess methods
+
+ @Override
+ public void backgroundProcess() {
+ // This method gets called once a second.
+ backgroundProcessCount ++;
+ if (backgroundProcessCount >= processPeriod) {
+ backgroundProcessCount = 0;
+
+ for (WsSession wsSession : sessions.keySet()) {
+ wsSession.checkExpiration();
+ }
+ }
+
+ }
+
+
+ @Override
+ public void setProcessPeriod(int period) {
+ this.processPeriod = period;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value is 10 which means session expirations are processed
+ * every 10 seconds.
+ */
+ @Override
+ public int getProcessPeriod() {
+ return processPeriod;
+ }
+
+
+ private static class HttpResponse {
+ private final int status;
+ private final HandshakeResponse handshakeResponse;
+
+ public HttpResponse(int status, HandshakeResponse handshakeResponse) {
+ this.status = status;
+ this.handshakeResponse = handshakeResponse;
+ }
+
+
+ public int getStatus() {
+ return status;
+ }
+
+
+ public HandshakeResponse getHandshakeResponse() {
+ return handshakeResponse;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/Constants.java b/src/java/nginx/unit/websocket/pojo/Constants.java
new file mode 100644
index 00000000..93cdecc7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/Constants.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+/**
+ * Internal implementation constants.
+ */
+public class Constants {
+
+ public static final String POJO_PATH_PARAM_KEY =
+ "nginx.unit.websocket.pojo.PojoEndpoint.pathParams";
+ public static final String POJO_METHOD_MAPPING_KEY =
+ "nginx.unit.websocket.pojo.PojoEndpoint.methodMapping";
+
+ private Constants() {
+ // Hide default constructor
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/LocalStrings.properties b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties
new file mode 100644
index 00000000..00ab7e6b
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+pojoEndpointBase.closeSessionFail=Failed to close WebSocket session during error handling
+pojoEndpointBase.onCloseFail=Failed to call onClose method of POJO end point for POJO of type [{0}]
+pojoEndpointBase.onError=No error handling configured for [{0}] and the following error occurred
+pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point for POJO of type [{0}]
+pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for POJO of type [{0}]
+pojoEndpointServer.getPojoInstanceFail=Failed to create instance of POJO of type [{0}]
+pojoMethodMapping.decodePathParamFail=Failed to decode path parameter value [{0}] to expected type [{1}]
+pojoMethodMapping.duplicateAnnotation=Duplicate annotations [{0}] present on class [{1}]
+pojoMethodMapping.duplicateLastParam=Multiple boolean (last) parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.duplicateMessageParam=Multiple message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.duplicatePongMessageParam=Multiple PongMessage parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.duplicateSessionParam=Multiple session parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.invalidDecoder=The specified decoder of type [{0}] could not be instantiated
+pojoMethodMapping.invalidPathParamType=Parameters annotated with @PathParam may only be Strings, Java primitives or a boxed version thereof
+pojoMethodMapping.methodNotPublic=The annotated method [{0}] is not public
+pojoMethodMapping.noPayload=No payload parameter present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.onErrorNoThrowable=No Throwable parameter was present on the method [{0}] of class [{1}] that was annotated with OnError
+pojoMethodMapping.paramWithoutAnnotation=A parameter of type [{0}] was found on method[{1}] of class [{2}] that did not have a @PathParam annotation
+pojoMethodMapping.partialInputStream=Invalid InputStream and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.partialObject=Invalid Object and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.partialPong=Invalid PongMessage and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.partialReader=Invalid Reader and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMethodMapping.pongWithPayload=Invalid PongMessage and Message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage
+pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message
+pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for this implementation is Integer.MAX_VALUE
diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java
new file mode 100644
index 00000000..be679a35
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.util.Map;
+import java.util.Set;
+
+import javax.websocket.CloseReason;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.MessageHandler;
+import javax.websocket.Session;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.ExceptionUtils;
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Base implementation (client and server have different concrete
+ * implementations) of the wrapper that converts a POJO instance into a
+ * WebSocket endpoint instance.
+ */
+public abstract class PojoEndpointBase extends Endpoint {
+
+ private final Log log = LogFactory.getLog(PojoEndpointBase.class); // must not be static
+ private static final StringManager sm = StringManager.getManager(PojoEndpointBase.class);
+
+ private Object pojo;
+ private Map<String,String> pathParameters;
+ private PojoMethodMapping methodMapping;
+
+
+ protected final void doOnOpen(Session session, EndpointConfig config) {
+ PojoMethodMapping methodMapping = getMethodMapping();
+ Object pojo = getPojo();
+ Map<String,String> pathParameters = getPathParameters();
+
+ // Add message handlers before calling onOpen since that may trigger a
+ // message which in turn could trigger a response and/or close the
+ // session
+ for (MessageHandler mh : methodMapping.getMessageHandlers(pojo,
+ pathParameters, session, config)) {
+ session.addMessageHandler(mh);
+ }
+
+ if (methodMapping.getOnOpen() != null) {
+ try {
+ methodMapping.getOnOpen().invoke(pojo,
+ methodMapping.getOnOpenArgs(
+ pathParameters, session, config));
+
+ } catch (IllegalAccessException e) {
+ // Reflection related problems
+ log.error(sm.getString(
+ "pojoEndpointBase.onOpenFail",
+ pojo.getClass().getName()), e);
+ handleOnOpenOrCloseError(session, e);
+ } catch (InvocationTargetException e) {
+ Throwable cause = e.getCause();
+ handleOnOpenOrCloseError(session, cause);
+ } catch (Throwable t) {
+ handleOnOpenOrCloseError(session, t);
+ }
+ }
+ }
+
+
+ private void handleOnOpenOrCloseError(Session session, Throwable t) {
+ // If really fatal - re-throw
+ ExceptionUtils.handleThrowable(t);
+
+ // Trigger the error handler and close the session
+ onError(session, t);
+ try {
+ session.close();
+ } catch (IOException ioe) {
+ log.warn(sm.getString("pojoEndpointBase.closeSessionFail"), ioe);
+ }
+ }
+
+ @Override
+ public final void onClose(Session session, CloseReason closeReason) {
+
+ if (methodMapping.getOnClose() != null) {
+ try {
+ methodMapping.getOnClose().invoke(pojo,
+ methodMapping.getOnCloseArgs(pathParameters, session, closeReason));
+ } catch (Throwable t) {
+ log.error(sm.getString("pojoEndpointBase.onCloseFail",
+ pojo.getClass().getName()), t);
+ handleOnOpenOrCloseError(session, t);
+ }
+ }
+
+ // Trigger the destroy method for any associated decoders
+ Set<MessageHandler> messageHandlers = session.getMessageHandlers();
+ for (MessageHandler messageHandler : messageHandlers) {
+ if (messageHandler instanceof PojoMessageHandlerWholeBase<?>) {
+ ((PojoMessageHandlerWholeBase<?>) messageHandler).onClose();
+ }
+ }
+ }
+
+
+ @Override
+ public final void onError(Session session, Throwable throwable) {
+
+ if (methodMapping.getOnError() == null) {
+ log.error(sm.getString("pojoEndpointBase.onError",
+ pojo.getClass().getName()), throwable);
+ } else {
+ try {
+ methodMapping.getOnError().invoke(
+ pojo,
+ methodMapping.getOnErrorArgs(pathParameters, session,
+ throwable));
+ } catch (Throwable t) {
+ ExceptionUtils.handleThrowable(t);
+ log.error(sm.getString("pojoEndpointBase.onErrorFail",
+ pojo.getClass().getName()), t);
+ }
+ }
+ }
+
+ protected Object getPojo() { return pojo; }
+ protected void setPojo(Object pojo) { this.pojo = pojo; }
+
+
+ protected Map<String,String> getPathParameters() { return pathParameters; }
+ protected void setPathParameters(Map<String,String> pathParameters) {
+ this.pathParameters = pathParameters;
+ }
+
+
+ protected PojoMethodMapping getMethodMapping() { return methodMapping; }
+ protected void setMethodMapping(PojoMethodMapping methodMapping) {
+ this.methodMapping = methodMapping;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java
new file mode 100644
index 00000000..6e569487
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.util.Collections;
+import java.util.List;
+
+import javax.websocket.Decoder;
+import javax.websocket.DeploymentException;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+
+
+/**
+ * Wrapper class for instances of POJOs annotated with
+ * {@link javax.websocket.ClientEndpoint} so they appear as standard
+ * {@link javax.websocket.Endpoint} instances.
+ */
+public class PojoEndpointClient extends PojoEndpointBase {
+
+ public PojoEndpointClient(Object pojo,
+ List<Class<? extends Decoder>> decoders) throws DeploymentException {
+ setPojo(pojo);
+ setMethodMapping(
+ new PojoMethodMapping(pojo.getClass(), decoders, null));
+ setPathParameters(Collections.<String,String>emptyMap());
+ }
+
+ @Override
+ public void onOpen(Session session, EndpointConfig config) {
+ doOnOpen(session, config);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java
new file mode 100644
index 00000000..499f8274
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.util.Map;
+
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+import javax.websocket.server.ServerEndpointConfig;
+
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Wrapper class for instances of POJOs annotated with
+ * {@link javax.websocket.server.ServerEndpoint} so they appear as standard
+ * {@link javax.websocket.Endpoint} instances.
+ */
+public class PojoEndpointServer extends PojoEndpointBase {
+
+ private static final StringManager sm =
+ StringManager.getManager(PojoEndpointServer.class);
+
+ @Override
+ public void onOpen(Session session, EndpointConfig endpointConfig) {
+
+ ServerEndpointConfig sec = (ServerEndpointConfig) endpointConfig;
+
+ Object pojo;
+ try {
+ pojo = sec.getConfigurator().getEndpointInstance(
+ sec.getEndpointClass());
+ } catch (InstantiationException e) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoEndpointServer.getPojoInstanceFail",
+ sec.getEndpointClass().getName()), e);
+ }
+ setPojo(pojo);
+
+ @SuppressWarnings("unchecked")
+ Map<String,String> pathParameters =
+ (Map<String, String>) sec.getUserProperties().get(
+ Constants.POJO_PATH_PARAM_KEY);
+ setPathParameters(pathParameters);
+
+ PojoMethodMapping methodMapping =
+ (PojoMethodMapping) sec.getUserProperties().get(
+ Constants.POJO_METHOD_MAPPING_KEY);
+ setMethodMapping(methodMapping);
+
+ doOnOpen(session, endpointConfig);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java
new file mode 100644
index 00000000..b72d719a
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+
+import javax.websocket.EncodeException;
+import javax.websocket.MessageHandler;
+import javax.websocket.RemoteEndpoint;
+import javax.websocket.Session;
+
+import org.apache.tomcat.util.ExceptionUtils;
+import nginx.unit.websocket.WrappedMessageHandler;
+
+/**
+ * Common implementation code for the POJO message handlers.
+ *
+ * @param <T> The type of message to handle
+ */
+public abstract class PojoMessageHandlerBase<T>
+ implements WrappedMessageHandler {
+
+ protected final Object pojo;
+ protected final Method method;
+ protected final Session session;
+ protected final Object[] params;
+ protected final int indexPayload;
+ protected final boolean convert;
+ protected final int indexSession;
+ protected final long maxMessageSize;
+
+ public PojoMessageHandlerBase(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload, boolean convert,
+ int indexSession, long maxMessageSize) {
+ this.pojo = pojo;
+ this.method = method;
+ // TODO: The method should already be accessible here but the following
+ // code seems to be necessary in some as yet not fully understood cases.
+ try {
+ this.method.setAccessible(true);
+ } catch (Exception e) {
+ // It is better to make sure the method is accessible, but
+ // ignore exceptions and hope for the best
+ }
+ this.session = session;
+ this.params = params;
+ this.indexPayload = indexPayload;
+ this.convert = convert;
+ this.indexSession = indexSession;
+ this.maxMessageSize = maxMessageSize;
+ }
+
+
+ protected final void processResult(Object result) {
+ if (result == null) {
+ return;
+ }
+
+ RemoteEndpoint.Basic remoteEndpoint = session.getBasicRemote();
+ try {
+ if (result instanceof String) {
+ remoteEndpoint.sendText((String) result);
+ } else if (result instanceof ByteBuffer) {
+ remoteEndpoint.sendBinary((ByteBuffer) result);
+ } else if (result instanceof byte[]) {
+ remoteEndpoint.sendBinary(ByteBuffer.wrap((byte[]) result));
+ } else {
+ remoteEndpoint.sendObject(result);
+ }
+ } catch (IOException | EncodeException ioe) {
+ throw new IllegalStateException(ioe);
+ }
+ }
+
+
+ /**
+ * Expose the POJO if it is a message handler so the Session is able to
+ * match requests to remove handlers if the original handler has been
+ * wrapped.
+ */
+ @Override
+ public final MessageHandler getWrappedHandler() {
+ if (pojo instanceof MessageHandler) {
+ return (MessageHandler) pojo;
+ } else {
+ return null;
+ }
+ }
+
+
+ @Override
+ public final long getMaxMessageSize() {
+ return maxMessageSize;
+ }
+
+
+ protected final void handlePojoMethodException(Throwable t) {
+ t = ExceptionUtils.unwrapInvocationTargetException(t);
+ ExceptionUtils.handleThrowable(t);
+ if (t instanceof RuntimeException) {
+ throw (RuntimeException) t;
+ } else {
+ throw new RuntimeException(t.getMessage(), t);
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java
new file mode 100644
index 00000000..d6f37724
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+
+import javax.websocket.DecodeException;
+import javax.websocket.MessageHandler;
+import javax.websocket.Session;
+
+import nginx.unit.websocket.WsSession;
+
+/**
+ * Common implementation code for the POJO partial message handlers. All
+ * the real work is done in this class and in the superclass.
+ *
+ * @param <T> The type of message to handle
+ */
+public abstract class PojoMessageHandlerPartialBase<T>
+ extends PojoMessageHandlerBase<T> implements MessageHandler.Partial<T> {
+
+ private final int indexBoolean;
+
+ public PojoMessageHandlerPartialBase(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload,
+ boolean convert, int indexBoolean, int indexSession,
+ long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert,
+ indexSession, maxMessageSize);
+ this.indexBoolean = indexBoolean;
+ }
+
+
+ @Override
+ public final void onMessage(T message, boolean last) {
+ if (params.length == 1 && params[0] instanceof DecodeException) {
+ ((WsSession) session).getLocal().onError(session,
+ (DecodeException) params[0]);
+ return;
+ }
+ Object[] parameters = params.clone();
+ if (indexBoolean != -1) {
+ parameters[indexBoolean] = Boolean.valueOf(last);
+ }
+ if (indexSession != -1) {
+ parameters[indexSession] = session;
+ }
+ if (convert) {
+ parameters[indexPayload] = ((ByteBuffer) message).array();
+ } else {
+ parameters[indexPayload] = message;
+ }
+ Object result = null;
+ try {
+ result = method.invoke(pojo, parameters);
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ handlePojoMethodException(e);
+ }
+ processResult(result);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java
new file mode 100644
index 00000000..1d334017
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+
+import javax.websocket.Session;
+
+/**
+ * ByteBuffer specific concrete implementation for handling partial messages.
+ */
+public class PojoMessageHandlerPartialBinary
+ extends PojoMessageHandlerPartialBase<ByteBuffer> {
+
+ public PojoMessageHandlerPartialBinary(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload, boolean convert,
+ int indexBoolean, int indexSession, long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert, indexBoolean,
+ indexSession, maxMessageSize);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java
new file mode 100644
index 00000000..8f7c1a0d
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.lang.reflect.Method;
+
+import javax.websocket.Session;
+
+/**
+ * Text specific concrete implementation for handling partial messages.
+ */
+public class PojoMessageHandlerPartialText
+ extends PojoMessageHandlerPartialBase<String> {
+
+ public PojoMessageHandlerPartialText(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload, boolean convert,
+ int indexBoolean, int indexSession, long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert, indexBoolean,
+ indexSession, maxMessageSize);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java
new file mode 100644
index 00000000..23333eb7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+
+import javax.websocket.DecodeException;
+import javax.websocket.MessageHandler;
+import javax.websocket.Session;
+
+import nginx.unit.websocket.WsSession;
+
+/**
+ * Common implementation code for the POJO whole message handlers. All the real
+ * work is done in this class and in the superclass.
+ *
+ * @param <T> The type of message to handle
+ */
+public abstract class PojoMessageHandlerWholeBase<T>
+ extends PojoMessageHandlerBase<T> implements MessageHandler.Whole<T> {
+
+ public PojoMessageHandlerWholeBase(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload,
+ boolean convert, int indexSession, long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert,
+ indexSession, maxMessageSize);
+ }
+
+
+ @Override
+ public final void onMessage(T message) {
+
+ if (params.length == 1 && params[0] instanceof DecodeException) {
+ ((WsSession) session).getLocal().onError(session,
+ (DecodeException) params[0]);
+ return;
+ }
+
+ // Can this message be decoded?
+ Object payload;
+ try {
+ payload = decode(message);
+ } catch (DecodeException de) {
+ ((WsSession) session).getLocal().onError(session, de);
+ return;
+ }
+
+ if (payload == null) {
+ // Not decoded. Convert if required.
+ if (convert) {
+ payload = convert(message);
+ } else {
+ payload = message;
+ }
+ }
+
+ Object[] parameters = params.clone();
+ if (indexSession != -1) {
+ parameters[indexSession] = session;
+ }
+ parameters[indexPayload] = payload;
+
+ Object result = null;
+ try {
+ result = method.invoke(pojo, parameters);
+ } catch (IllegalAccessException | InvocationTargetException e) {
+ handlePojoMethodException(e);
+ }
+ processResult(result);
+ }
+
+ protected Object convert(T message) {
+ return message;
+ }
+
+
+ protected abstract Object decode(T message) throws DecodeException;
+ protected abstract void onClose();
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java
new file mode 100644
index 00000000..07ff0648
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.lang.reflect.Method;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.websocket.DecodeException;
+import javax.websocket.Decoder;
+import javax.websocket.Decoder.Binary;
+import javax.websocket.Decoder.BinaryStream;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * ByteBuffer specific concrete implementation for handling whole messages.
+ */
+public class PojoMessageHandlerWholeBinary
+ extends PojoMessageHandlerWholeBase<ByteBuffer> {
+
+ private static final StringManager sm =
+ StringManager.getManager(PojoMessageHandlerWholeBinary.class);
+
+ private final List<Decoder> decoders = new ArrayList<>();
+
+ private final boolean isForInputStream;
+
+ public PojoMessageHandlerWholeBinary(Object pojo, Method method,
+ Session session, EndpointConfig config,
+ List<Class<? extends Decoder>> decoderClazzes, Object[] params,
+ int indexPayload, boolean convert, int indexSession,
+ boolean isForInputStream, long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert,
+ indexSession, maxMessageSize);
+
+ // Update binary text size handled by session
+ if (maxMessageSize > -1 && maxMessageSize > session.getMaxBinaryMessageBufferSize()) {
+ if (maxMessageSize > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMessageHandlerWhole.maxBufferSize"));
+ }
+ session.setMaxBinaryMessageBufferSize((int) maxMessageSize);
+ }
+
+ try {
+ if (decoderClazzes != null) {
+ for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
+ if (Binary.class.isAssignableFrom(decoderClazz)) {
+ Binary<?> decoder = (Binary<?>) decoderClazz.getConstructor().newInstance();
+ decoder.init(config);
+ decoders.add(decoder);
+ } else if (BinaryStream.class.isAssignableFrom(
+ decoderClazz)) {
+ BinaryStream<?> decoder = (BinaryStream<?>)
+ decoderClazz.getConstructor().newInstance();
+ decoder.init(config);
+ decoders.add(decoder);
+ } else {
+ // Text decoder - ignore it
+ }
+ }
+ }
+ } catch (ReflectiveOperationException e) {
+ throw new IllegalArgumentException(e);
+ }
+ this.isForInputStream = isForInputStream;
+ }
+
+
+ @Override
+ protected Object decode(ByteBuffer message) throws DecodeException {
+ for (Decoder decoder : decoders) {
+ if (decoder instanceof Binary) {
+ if (((Binary<?>) decoder).willDecode(message)) {
+ return ((Binary<?>) decoder).decode(message);
+ }
+ } else {
+ byte[] array = new byte[message.limit() - message.position()];
+ message.get(array);
+ ByteArrayInputStream bais = new ByteArrayInputStream(array);
+ try {
+ return ((BinaryStream<?>) decoder).decode(bais);
+ } catch (IOException ioe) {
+ throw new DecodeException(message, sm.getString(
+ "pojoMessageHandlerWhole.decodeIoFail"), ioe);
+ }
+ }
+ }
+ return null;
+ }
+
+
+ @Override
+ protected Object convert(ByteBuffer message) {
+ byte[] array = new byte[message.remaining()];
+ message.get(array);
+ if (isForInputStream) {
+ return new ByteArrayInputStream(array);
+ } else {
+ return array;
+ }
+ }
+
+
+ @Override
+ protected void onClose() {
+ for (Decoder decoder : decoders) {
+ decoder.destroy();
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java
new file mode 100644
index 00000000..bdedd7de
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.lang.reflect.Method;
+
+import javax.websocket.PongMessage;
+import javax.websocket.Session;
+
+/**
+ * PongMessage specific concrete implementation for handling whole messages.
+ */
+public class PojoMessageHandlerWholePong
+ extends PojoMessageHandlerWholeBase<PongMessage> {
+
+ public PojoMessageHandlerWholePong(Object pojo, Method method,
+ Session session, Object[] params, int indexPayload, boolean convert,
+ int indexSession) {
+ super(pojo, method, session, params, indexPayload, convert,
+ indexSession, -1);
+ }
+
+ @Override
+ protected Object decode(PongMessage message) {
+ // Never decoded
+ return null;
+ }
+
+
+ @Override
+ protected void onClose() {
+ // NO-OP
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java
new file mode 100644
index 00000000..59007349
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.io.IOException;
+import java.io.StringReader;
+import java.lang.reflect.Method;
+import java.util.ArrayList;
+import java.util.List;
+
+import javax.websocket.DecodeException;
+import javax.websocket.Decoder;
+import javax.websocket.Decoder.Text;
+import javax.websocket.Decoder.TextStream;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Session;
+
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.Util;
+
+
+/**
+ * Text specific concrete implementation for handling whole messages.
+ */
+public class PojoMessageHandlerWholeText
+ extends PojoMessageHandlerWholeBase<String> {
+
+ private static final StringManager sm =
+ StringManager.getManager(PojoMessageHandlerWholeText.class);
+
+ private final List<Decoder> decoders = new ArrayList<>();
+ private final Class<?> primitiveType;
+
+ public PojoMessageHandlerWholeText(Object pojo, Method method,
+ Session session, EndpointConfig config,
+ List<Class<? extends Decoder>> decoderClazzes, Object[] params,
+ int indexPayload, boolean convert, int indexSession,
+ long maxMessageSize) {
+ super(pojo, method, session, params, indexPayload, convert,
+ indexSession, maxMessageSize);
+
+ // Update max text size handled by session
+ if (maxMessageSize > -1 && maxMessageSize > session.getMaxTextMessageBufferSize()) {
+ if (maxMessageSize > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMessageHandlerWhole.maxBufferSize"));
+ }
+ session.setMaxTextMessageBufferSize((int) maxMessageSize);
+ }
+
+ // Check for primitives
+ Class<?> type = method.getParameterTypes()[indexPayload];
+ if (Util.isPrimitive(type)) {
+ primitiveType = type;
+ return;
+ } else {
+ primitiveType = null;
+ }
+
+ try {
+ if (decoderClazzes != null) {
+ for (Class<? extends Decoder> decoderClazz : decoderClazzes) {
+ if (Text.class.isAssignableFrom(decoderClazz)) {
+ Text<?> decoder = (Text<?>) decoderClazz.getConstructor().newInstance();
+ decoder.init(config);
+ decoders.add(decoder);
+ } else if (TextStream.class.isAssignableFrom(
+ decoderClazz)) {
+ TextStream<?> decoder =
+ (TextStream<?>) decoderClazz.getConstructor().newInstance();
+ decoder.init(config);
+ decoders.add(decoder);
+ } else {
+ // Binary decoder - ignore it
+ }
+ }
+ }
+ } catch (ReflectiveOperationException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+
+ @Override
+ protected Object decode(String message) throws DecodeException {
+ // Handle primitives
+ if (primitiveType != null) {
+ return Util.coerceToType(primitiveType, message);
+ }
+ // Handle full decoders
+ for (Decoder decoder : decoders) {
+ if (decoder instanceof Text) {
+ if (((Text<?>) decoder).willDecode(message)) {
+ return ((Text<?>) decoder).decode(message);
+ }
+ } else {
+ StringReader r = new StringReader(message);
+ try {
+ return ((TextStream<?>) decoder).decode(r);
+ } catch (IOException ioe) {
+ throw new DecodeException(message, sm.getString(
+ "pojoMessageHandlerWhole.decodeIoFail"), ioe);
+ }
+ }
+ }
+ return null;
+ }
+
+
+ @Override
+ protected Object convert(String message) {
+ return new StringReader(message);
+ }
+
+
+ @Override
+ protected void onClose() {
+ for (Decoder decoder : decoders) {
+ decoder.destroy();
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java
new file mode 100644
index 00000000..2385b5c7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java
@@ -0,0 +1,731 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+import java.io.InputStream;
+import java.io.Reader;
+import java.lang.annotation.Annotation;
+import java.lang.reflect.Method;
+import java.lang.reflect.Modifier;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.websocket.CloseReason;
+import javax.websocket.DecodeException;
+import javax.websocket.Decoder;
+import javax.websocket.DeploymentException;
+import javax.websocket.EndpointConfig;
+import javax.websocket.MessageHandler;
+import javax.websocket.OnClose;
+import javax.websocket.OnError;
+import javax.websocket.OnMessage;
+import javax.websocket.OnOpen;
+import javax.websocket.PongMessage;
+import javax.websocket.Session;
+import javax.websocket.server.PathParam;
+
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.DecoderEntry;
+import nginx.unit.websocket.Util;
+import nginx.unit.websocket.Util.DecoderMatch;
+
+/**
+ * For a POJO class annotated with
+ * {@link javax.websocket.server.ServerEndpoint}, an instance of this class
+ * creates and caches the method handler, method information and parameter
+ * information for the onXXX calls.
+ */
+public class PojoMethodMapping {
+
+ private static final StringManager sm =
+ StringManager.getManager(PojoMethodMapping.class);
+
+ private final Method onOpen;
+ private final Method onClose;
+ private final Method onError;
+ private final PojoPathParam[] onOpenParams;
+ private final PojoPathParam[] onCloseParams;
+ private final PojoPathParam[] onErrorParams;
+ private final List<MessageHandlerInfo> onMessage = new ArrayList<>();
+ private final String wsPath;
+
+
+ public PojoMethodMapping(Class<?> clazzPojo,
+ List<Class<? extends Decoder>> decoderClazzes, String wsPath)
+ throws DeploymentException {
+
+ this.wsPath = wsPath;
+
+ List<DecoderEntry> decoders = Util.getDecoders(decoderClazzes);
+ Method open = null;
+ Method close = null;
+ Method error = null;
+ Method[] clazzPojoMethods = null;
+ Class<?> currentClazz = clazzPojo;
+ while (!currentClazz.equals(Object.class)) {
+ Method[] currentClazzMethods = currentClazz.getDeclaredMethods();
+ if (currentClazz == clazzPojo) {
+ clazzPojoMethods = currentClazzMethods;
+ }
+ for (Method method : currentClazzMethods) {
+ if (method.getAnnotation(OnOpen.class) != null) {
+ checkPublic(method);
+ if (open == null) {
+ open = method;
+ } else {
+ if (currentClazz == clazzPojo ||
+ !isMethodOverride(open, method)) {
+ // Duplicate annotation
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.duplicateAnnotation",
+ OnOpen.class, currentClazz));
+ }
+ }
+ } else if (method.getAnnotation(OnClose.class) != null) {
+ checkPublic(method);
+ if (close == null) {
+ close = method;
+ } else {
+ if (currentClazz == clazzPojo ||
+ !isMethodOverride(close, method)) {
+ // Duplicate annotation
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.duplicateAnnotation",
+ OnClose.class, currentClazz));
+ }
+ }
+ } else if (method.getAnnotation(OnError.class) != null) {
+ checkPublic(method);
+ if (error == null) {
+ error = method;
+ } else {
+ if (currentClazz == clazzPojo ||
+ !isMethodOverride(error, method)) {
+ // Duplicate annotation
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.duplicateAnnotation",
+ OnError.class, currentClazz));
+ }
+ }
+ } else if (method.getAnnotation(OnMessage.class) != null) {
+ checkPublic(method);
+ MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders);
+ boolean found = false;
+ for (MessageHandlerInfo otherMessageHandler : onMessage) {
+ if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) {
+ found = true;
+ if (currentClazz == clazzPojo ||
+ !isMethodOverride(messageHandler.m, otherMessageHandler.m)) {
+ // Duplicate annotation
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.duplicateAnnotation",
+ OnMessage.class, currentClazz));
+ }
+ }
+ }
+ if (!found) {
+ onMessage.add(messageHandler);
+ }
+ } else {
+ // Method not annotated
+ }
+ }
+ currentClazz = currentClazz.getSuperclass();
+ }
+ // If the methods are not on clazzPojo and they are overridden
+ // by a non annotated method in clazzPojo, they should be ignored
+ if (open != null && open.getDeclaringClass() != clazzPojo) {
+ if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) {
+ open = null;
+ }
+ }
+ if (close != null && close.getDeclaringClass() != clazzPojo) {
+ if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) {
+ close = null;
+ }
+ }
+ if (error != null && error.getDeclaringClass() != clazzPojo) {
+ if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) {
+ error = null;
+ }
+ }
+ List<MessageHandlerInfo> overriddenOnMessage = new ArrayList<>();
+ for (MessageHandlerInfo messageHandler : onMessage) {
+ if (messageHandler.m.getDeclaringClass() != clazzPojo
+ && isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) {
+ overriddenOnMessage.add(messageHandler);
+ }
+ }
+ for (MessageHandlerInfo messageHandler : overriddenOnMessage) {
+ onMessage.remove(messageHandler);
+ }
+ this.onOpen = open;
+ this.onClose = close;
+ this.onError = error;
+ onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN);
+ onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE);
+ onErrorParams = getPathParams(onError, MethodType.ON_ERROR);
+ }
+
+
+ private void checkPublic(Method m) throws DeploymentException {
+ if (!Modifier.isPublic(m.getModifiers())) {
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.methodNotPublic", m.getName()));
+ }
+ }
+
+
+ private boolean isMethodOverride(Method method1, Method method2) {
+ return method1.getName().equals(method2.getName())
+ && method1.getReturnType().equals(method2.getReturnType())
+ && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes());
+ }
+
+
+ private boolean isOverridenWithoutAnnotation(Method[] methods,
+ Method superclazzMethod, Class<? extends Annotation> annotation) {
+ for (Method method : methods) {
+ if (isMethodOverride(method, superclazzMethod)
+ && (method.getAnnotation(annotation) == null)) {
+ return true;
+ }
+ }
+ return false;
+ }
+
+
+ public String getWsPath() {
+ return wsPath;
+ }
+
+
+ public Method getOnOpen() {
+ return onOpen;
+ }
+
+
+ public Object[] getOnOpenArgs(Map<String,String> pathParameters,
+ Session session, EndpointConfig config) throws DecodeException {
+ return buildArgs(onOpenParams, pathParameters, session, config, null,
+ null);
+ }
+
+
+ public Method getOnClose() {
+ return onClose;
+ }
+
+
+ public Object[] getOnCloseArgs(Map<String,String> pathParameters,
+ Session session, CloseReason closeReason) throws DecodeException {
+ return buildArgs(onCloseParams, pathParameters, session, null, null,
+ closeReason);
+ }
+
+
+ public Method getOnError() {
+ return onError;
+ }
+
+
+ public Object[] getOnErrorArgs(Map<String,String> pathParameters,
+ Session session, Throwable throwable) throws DecodeException {
+ return buildArgs(onErrorParams, pathParameters, session, null,
+ throwable, null);
+ }
+
+
+ public boolean hasMessageHandlers() {
+ return !onMessage.isEmpty();
+ }
+
+
+ public Set<MessageHandler> getMessageHandlers(Object pojo,
+ Map<String,String> pathParameters, Session session,
+ EndpointConfig config) {
+ Set<MessageHandler> result = new HashSet<>();
+ for (MessageHandlerInfo messageMethod : onMessage) {
+ result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters,
+ session, config));
+ }
+ return result;
+ }
+
+
+ private static PojoPathParam[] getPathParams(Method m,
+ MethodType methodType) throws DeploymentException {
+ if (m == null) {
+ return new PojoPathParam[0];
+ }
+ boolean foundThrowable = false;
+ Class<?>[] types = m.getParameterTypes();
+ Annotation[][] paramsAnnotations = m.getParameterAnnotations();
+ PojoPathParam[] result = new PojoPathParam[types.length];
+ for (int i = 0; i < types.length; i++) {
+ Class<?> type = types[i];
+ if (type.equals(Session.class)) {
+ result[i] = new PojoPathParam(type, null);
+ } else if (methodType == MethodType.ON_OPEN &&
+ type.equals(EndpointConfig.class)) {
+ result[i] = new PojoPathParam(type, null);
+ } else if (methodType == MethodType.ON_ERROR
+ && type.equals(Throwable.class)) {
+ foundThrowable = true;
+ result[i] = new PojoPathParam(type, null);
+ } else if (methodType == MethodType.ON_CLOSE &&
+ type.equals(CloseReason.class)) {
+ result[i] = new PojoPathParam(type, null);
+ } else {
+ Annotation[] paramAnnotations = paramsAnnotations[i];
+ for (Annotation paramAnnotation : paramAnnotations) {
+ if (paramAnnotation.annotationType().equals(
+ PathParam.class)) {
+ // Check that the type is valid. "0" coerces to every
+ // valid type
+ try {
+ Util.coerceToType(type, "0");
+ } catch (IllegalArgumentException iae) {
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.invalidPathParamType"),
+ iae);
+ }
+ result[i] = new PojoPathParam(type,
+ ((PathParam) paramAnnotation).value());
+ break;
+ }
+ }
+ // Parameters without annotations are not permitted
+ if (result[i] == null) {
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.paramWithoutAnnotation",
+ type, m.getName(), m.getClass().getName()));
+ }
+ }
+ }
+ if (methodType == MethodType.ON_ERROR && !foundThrowable) {
+ throw new DeploymentException(sm.getString(
+ "pojoMethodMapping.onErrorNoThrowable",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ return result;
+ }
+
+
+ private static Object[] buildArgs(PojoPathParam[] pathParams,
+ Map<String,String> pathParameters, Session session,
+ EndpointConfig config, Throwable throwable, CloseReason closeReason)
+ throws DecodeException {
+ Object[] result = new Object[pathParams.length];
+ for (int i = 0; i < pathParams.length; i++) {
+ Class<?> type = pathParams[i].getType();
+ if (type.equals(Session.class)) {
+ result[i] = session;
+ } else if (type.equals(EndpointConfig.class)) {
+ result[i] = config;
+ } else if (type.equals(Throwable.class)) {
+ result[i] = throwable;
+ } else if (type.equals(CloseReason.class)) {
+ result[i] = closeReason;
+ } else {
+ String name = pathParams[i].getName();
+ String value = pathParameters.get(name);
+ try {
+ result[i] = Util.coerceToType(type, value);
+ } catch (Exception e) {
+ throw new DecodeException(value, sm.getString(
+ "pojoMethodMapping.decodePathParamFail",
+ value, type), e);
+ }
+ }
+ }
+ return result;
+ }
+
+
+ private static class MessageHandlerInfo {
+
+ private final Method m;
+ private int indexString = -1;
+ private int indexByteArray = -1;
+ private int indexByteBuffer = -1;
+ private int indexPong = -1;
+ private int indexBoolean = -1;
+ private int indexSession = -1;
+ private int indexInputStream = -1;
+ private int indexReader = -1;
+ private int indexPrimitive = -1;
+ private Class<?> primitiveType = null;
+ private Map<Integer,PojoPathParam> indexPathParams = new HashMap<>();
+ private int indexPayload = -1;
+ private DecoderMatch decoderMatch = null;
+ private long maxMessageSize = -1;
+
+ public MessageHandlerInfo(Method m, List<DecoderEntry> decoderEntries) {
+ this.m = m;
+
+ Class<?>[] types = m.getParameterTypes();
+ Annotation[][] paramsAnnotations = m.getParameterAnnotations();
+
+ for (int i = 0; i < types.length; i++) {
+ boolean paramFound = false;
+ Annotation[] paramAnnotations = paramsAnnotations[i];
+ for (Annotation paramAnnotation : paramAnnotations) {
+ if (paramAnnotation.annotationType().equals(
+ PathParam.class)) {
+ indexPathParams.put(
+ Integer.valueOf(i), new PojoPathParam(types[i],
+ ((PathParam) paramAnnotation).value()));
+ paramFound = true;
+ break;
+ }
+ }
+ if (paramFound) {
+ continue;
+ }
+ if (String.class.isAssignableFrom(types[i])) {
+ if (indexString == -1) {
+ indexString = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (Reader.class.isAssignableFrom(types[i])) {
+ if (indexReader == -1) {
+ indexReader = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (boolean.class == types[i]) {
+ if (indexBoolean == -1) {
+ indexBoolean = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateLastParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (ByteBuffer.class.isAssignableFrom(types[i])) {
+ if (indexByteBuffer == -1) {
+ indexByteBuffer = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (byte[].class == types[i]) {
+ if (indexByteArray == -1) {
+ indexByteArray = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (InputStream.class.isAssignableFrom(types[i])) {
+ if (indexInputStream == -1) {
+ indexInputStream = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (Util.isPrimitive(types[i])) {
+ if (indexPrimitive == -1) {
+ indexPrimitive = i;
+ primitiveType = types[i];
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (Session.class.isAssignableFrom(types[i])) {
+ if (indexSession == -1) {
+ indexSession = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateSessionParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else if (PongMessage.class.isAssignableFrom(types[i])) {
+ if (indexPong == -1) {
+ indexPong = i;
+ } else {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicatePongMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ } else {
+ if (decoderMatch != null && decoderMatch.hasMatches()) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ decoderMatch = new DecoderMatch(types[i], decoderEntries);
+
+ if (decoderMatch.hasMatches()) {
+ indexPayload = i;
+ }
+ }
+ }
+
+ // Additional checks required
+ if (indexString != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexString;
+ }
+ }
+ if (indexReader != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexReader;
+ }
+ }
+ if (indexByteArray != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexByteArray;
+ }
+ }
+ if (indexByteBuffer != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexByteBuffer;
+ }
+ }
+ if (indexInputStream != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexInputStream;
+ }
+ }
+ if (indexPrimitive != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.duplicateMessageParam",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexPrimitive;
+ }
+ }
+ if (indexPong != -1) {
+ if (indexPayload != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.pongWithPayload",
+ m.getName(), m.getDeclaringClass().getName()));
+ } else {
+ indexPayload = indexPong;
+ }
+ }
+ if (indexPayload == -1 && indexPrimitive == -1 &&
+ indexBoolean != -1) {
+ // The boolean we found is a payload, not a last flag
+ indexPayload = indexBoolean;
+ indexPrimitive = indexBoolean;
+ primitiveType = Boolean.TYPE;
+ indexBoolean = -1;
+ }
+ if (indexPayload == -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.noPayload",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ if (indexPong != -1 && indexBoolean != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.partialPong",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ if(indexReader != -1 && indexBoolean != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.partialReader",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ if(indexInputStream != -1 && indexBoolean != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.partialInputStream",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+ if (decoderMatch != null && decoderMatch.hasMatches() &&
+ indexBoolean != -1) {
+ throw new IllegalArgumentException(sm.getString(
+ "pojoMethodMapping.partialObject",
+ m.getName(), m.getDeclaringClass().getName()));
+ }
+
+ maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize();
+ }
+
+
+ public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) {
+ if (otherHandler == null) {
+ return false;
+ }
+ if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) {
+ return true;
+ }
+ if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) {
+ return true;
+ }
+ if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) {
+ return true;
+ }
+ if (indexPong >= 0 && otherHandler.indexPong >= 0) {
+ return true;
+ }
+ if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0
+ && primitiveType == otherHandler.primitiveType) {
+ return true;
+ }
+ if (indexReader >= 0 && otherHandler.indexReader >= 0) {
+ return true;
+ }
+ if (indexString >= 0 && otherHandler.indexString >= 0) {
+ return true;
+ }
+ if (decoderMatch != null && otherHandler.decoderMatch != null
+ && decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) {
+ return true;
+ }
+ return false;
+ }
+
+
+ public Set<MessageHandler> getMessageHandlers(Object pojo,
+ Map<String,String> pathParameters, Session session,
+ EndpointConfig config) {
+ Object[] params = new Object[m.getParameterTypes().length];
+
+ for (Map.Entry<Integer,PojoPathParam> entry :
+ indexPathParams.entrySet()) {
+ PojoPathParam pathParam = entry.getValue();
+ String valueString = pathParameters.get(pathParam.getName());
+ Object value = null;
+ try {
+ value = Util.coerceToType(pathParam.getType(), valueString);
+ } catch (Exception e) {
+ DecodeException de = new DecodeException(valueString,
+ sm.getString(
+ "pojoMethodMapping.decodePathParamFail",
+ valueString, pathParam.getType()), e);
+ params = new Object[] { de };
+ break;
+ }
+ params[entry.getKey().intValue()] = value;
+ }
+
+ Set<MessageHandler> results = new HashSet<>(2);
+ if (indexBoolean == -1) {
+ // Basic
+ if (indexString != -1 || indexPrimitive != -1) {
+ MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
+ session, config, null, params, indexPayload, false,
+ indexSession, maxMessageSize);
+ results.add(mh);
+ } else if (indexReader != -1) {
+ MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m,
+ session, config, null, params, indexReader, true,
+ indexSession, maxMessageSize);
+ results.add(mh);
+ } else if (indexByteArray != -1) {
+ MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
+ m, session, config, null, params, indexByteArray,
+ true, indexSession, false, maxMessageSize);
+ results.add(mh);
+ } else if (indexByteBuffer != -1) {
+ MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
+ m, session, config, null, params, indexByteBuffer,
+ false, indexSession, false, maxMessageSize);
+ results.add(mh);
+ } else if (indexInputStream != -1) {
+ MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo,
+ m, session, config, null, params, indexInputStream,
+ true, indexSession, true, maxMessageSize);
+ results.add(mh);
+ } else if (decoderMatch != null && decoderMatch.hasMatches()) {
+ if (decoderMatch.getBinaryDecoders().size() > 0) {
+ MessageHandler mh = new PojoMessageHandlerWholeBinary(
+ pojo, m, session, config,
+ decoderMatch.getBinaryDecoders(), params,
+ indexPayload, true, indexSession, true,
+ maxMessageSize);
+ results.add(mh);
+ }
+ if (decoderMatch.getTextDecoders().size() > 0) {
+ MessageHandler mh = new PojoMessageHandlerWholeText(
+ pojo, m, session, config,
+ decoderMatch.getTextDecoders(), params,
+ indexPayload, true, indexSession, maxMessageSize);
+ results.add(mh);
+ }
+ } else {
+ MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m,
+ session, params, indexPong, false, indexSession);
+ results.add(mh);
+ }
+ } else {
+ // ASync
+ if (indexString != -1) {
+ MessageHandler mh = new PojoMessageHandlerPartialText(pojo,
+ m, session, params, indexString, false,
+ indexBoolean, indexSession, maxMessageSize);
+ results.add(mh);
+ } else if (indexByteArray != -1) {
+ MessageHandler mh = new PojoMessageHandlerPartialBinary(
+ pojo, m, session, params, indexByteArray, true,
+ indexBoolean, indexSession, maxMessageSize);
+ results.add(mh);
+ } else {
+ MessageHandler mh = new PojoMessageHandlerPartialBinary(
+ pojo, m, session, params, indexByteBuffer, false,
+ indexBoolean, indexSession, maxMessageSize);
+ results.add(mh);
+ }
+ }
+ return results;
+ }
+ }
+
+
+ private enum MethodType {
+ ON_OPEN,
+ ON_CLOSE,
+ ON_ERROR
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/PojoPathParam.java b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java
new file mode 100644
index 00000000..859b6d68
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.pojo;
+
+/**
+ * Stores the parameter type and name for a parameter that needs to be passed to
+ * an onXxx method of {@link javax.websocket.Endpoint}. The name is only present
+ * for parameters annotated with
+ * {@link javax.websocket.server.PathParam}. For the
+ * {@link javax.websocket.Session} and {@link java.lang.Throwable} parameters,
+ * {@link #getName()} will always return <code>null</code>.
+ */
+public class PojoPathParam {
+
+ private final Class<?> type;
+ private final String name;
+
+
+ public PojoPathParam(Class<?> type, String name) {
+ this.type = type;
+ this.name = name;
+ }
+
+
+ public Class<?> getType() {
+ return type;
+ }
+
+
+ public String getName() {
+ return name;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/pojo/package-info.java b/src/java/nginx/unit/websocket/pojo/package-info.java
new file mode 100644
index 00000000..39cf80c8
--- /dev/null
+++ b/src/java/nginx/unit/websocket/pojo/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * This package provides the necessary plumbing to convert an annotated POJO
+ * into a WebSocket {@link javax.websocket.Endpoint}.
+ */
+package nginx.unit.websocket.pojo;
diff --git a/src/java/nginx/unit/websocket/server/Constants.java b/src/java/nginx/unit/websocket/server/Constants.java
new file mode 100644
index 00000000..5210c4ba
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/Constants.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+/**
+ * Internal implementation constants.
+ */
+public class Constants {
+
+ public static final String BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM =
+ "nginx.unit.websocket.binaryBufferSize";
+ public static final String TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM =
+ "nginx.unit.websocket.textBufferSize";
+ public static final String ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM =
+ "nginx.unit.websocket.noAddAfterHandshake";
+
+ public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE =
+ "javax.websocket.server.ServerContainer";
+
+
+ private Constants() {
+ // Hide default constructor
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java
new file mode 100644
index 00000000..43ffe2bc
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.util.ArrayList;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.server.HandshakeRequest;
+import javax.websocket.server.ServerEndpointConfig;
+
+public class DefaultServerEndpointConfigurator
+ extends ServerEndpointConfig.Configurator {
+
+ @Override
+ public <T> T getEndpointInstance(Class<T> clazz)
+ throws InstantiationException {
+ try {
+ return clazz.getConstructor().newInstance();
+ } catch (InstantiationException e) {
+ throw e;
+ } catch (ReflectiveOperationException e) {
+ InstantiationException ie = new InstantiationException();
+ ie.initCause(e);
+ throw ie;
+ }
+ }
+
+
+ @Override
+ public String getNegotiatedSubprotocol(List<String> supported,
+ List<String> requested) {
+
+ for (String request : requested) {
+ if (supported.contains(request)) {
+ return request;
+ }
+ }
+ return "";
+ }
+
+
+ @Override
+ public List<Extension> getNegotiatedExtensions(List<Extension> installed,
+ List<Extension> requested) {
+ Set<String> installedNames = new HashSet<>();
+ for (Extension e : installed) {
+ installedNames.add(e.getName());
+ }
+ List<Extension> result = new ArrayList<>();
+ for (Extension request : requested) {
+ if (installedNames.contains(request.getName())) {
+ result.add(request);
+ }
+ }
+ return result;
+ }
+
+
+ @Override
+ public boolean checkOrigin(String originHeaderValue) {
+ return true;
+ }
+
+ @Override
+ public void modifyHandshake(ServerEndpointConfig sec,
+ HandshakeRequest request, HandshakeResponse response) {
+ // NO-OP
+ }
+
+}
diff --git a/src/java/nginx/unit/websocket/server/LocalStrings.properties b/src/java/nginx/unit/websocket/server/LocalStrings.properties
new file mode 100644
index 00000000..5bc12501
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/LocalStrings.properties
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+serverContainer.addNotAllowed=No further Endpoints may be registered once an attempt has been made to use one of the previously registered endpoints
+serverContainer.configuratorFail=Failed to create configurator of type [{0}] for POJO of type [{1}]
+serverContainer.duplicatePaths=Multiple Endpoints may not be deployed to the same path [{0}] : existing endpoint was [{1}] and new endpoint is [{2}]
+serverContainer.encoderFail=Unable to create encoder of type [{0}]
+serverContainer.endpointDeploy=Endpoint class [{0}] deploying to path [{1}] in ServletContext [{2}]
+serverContainer.missingAnnotation=Cannot deploy POJO class [{0}] as it is not annotated with @ServerEndpoint
+serverContainer.missingEndpoint=An Endpoint instance has been request for path [{0}] but no matching Endpoint class was found
+serverContainer.pojoDeploy=POJO class [{0}] deploying to path [{1}] in ServletContext [{2}]
+serverContainer.servletContextMismatch=Attempted to register a POJO annotated for WebSocket at path [{0}] in the ServletContext with context path [{1}] when the WebSocket ServerContainer is allocated to the ServletContext with context path [{2}]
+serverContainer.servletContextMissing=No ServletContext was specified
+
+upgradeUtil.incompatibleRsv=Extensions were specified that have incompatible RSV bit usage
+
+uriTemplate.duplicateParameter=The parameter [{0}] appears more than once in the path which is not permitted
+uriTemplate.emptySegment=The path [{0}] contains one or more empty segments which are is not permitted
+uriTemplate.invalidPath=The path [{0}] is not valid.
+uriTemplate.invalidSegment=The segment [{0}] is not valid in the provided path [{1}]
+
+wsFrameServer.bytesRead=Read [{0}] bytes into input buffer ready for processing
+wsFrameServer.illegalReadState=Unexpected read state [{0}]
+wsFrameServer.onDataAvailable=Method entry
+
+wsHttpUpgradeHandler.closeOnError=Closing WebSocket connection due to an error
+wsHttpUpgradeHandler.destroyFailed=Failed to close WebConnection while destroying the WebSocket HttpUpgradeHandler
+wsHttpUpgradeHandler.noPreInit=The preInit() method must be called to configure the WebSocket HttpUpgradeHandler before the container calls init(). Usually, this means the Servlet that created the WsHttpUpgradeHandler instance should also call preInit()
+wsHttpUpgradeHandler.serverStop=The server is stopping
+
+wsRemoteEndpointServer.closeFailed=Failed to close the ServletOutputStream connection cleanly
diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java
new file mode 100644
index 00000000..162f01c7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.websocket.Endpoint;
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.server.ServerEndpointConfig;
+
+import nginx.unit.Request;
+
+import org.apache.tomcat.util.codec.binary.Base64;
+import org.apache.tomcat.util.res.StringManager;
+import org.apache.tomcat.util.security.ConcurrentMessageDigest;
+import nginx.unit.websocket.Constants;
+import nginx.unit.websocket.Transformation;
+import nginx.unit.websocket.TransformationFactory;
+import nginx.unit.websocket.Util;
+import nginx.unit.websocket.WsHandshakeResponse;
+import nginx.unit.websocket.pojo.PojoEndpointServer;
+
+public class UpgradeUtil {
+
+ private static final StringManager sm =
+ StringManager.getManager(UpgradeUtil.class.getPackage().getName());
+ private static final byte[] WS_ACCEPT =
+ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
+ StandardCharsets.ISO_8859_1);
+
+ private UpgradeUtil() {
+ // Utility class. Hide default constructor.
+ }
+
+ /**
+ * Checks to see if this is an HTTP request that includes a valid upgrade
+ * request to web socket.
+ * <p>
+ * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
+ * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
+ * 6455 section 4.1 requires that a WebSocket Upgrade uses GET.
+ * @param request The request to check if it is an HTTP upgrade request for
+ * a WebSocket connection
+ * @param response The response associated with the request
+ * @return <code>true</code> if the request includes a HTTP Upgrade request
+ * for the WebSocket protocol, otherwise <code>false</code>
+ */
+ public static boolean isWebSocketUpgradeRequest(ServletRequest request,
+ ServletResponse response) {
+
+ Request r = (Request) request.getAttribute(Request.BARE);
+
+ return ((request instanceof HttpServletRequest) &&
+ (response instanceof HttpServletResponse) &&
+ (r != null) &&
+ (r.isUpgrade()));
+ }
+
+
+ public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
+ HttpServletResponse resp, ServerEndpointConfig sec,
+ Map<String,String> pathParams)
+ throws ServletException, IOException {
+
+
+ // Origin check
+ String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
+
+ if (!sec.getConfigurator().checkOrigin(origin)) {
+ resp.sendError(HttpServletResponse.SC_FORBIDDEN);
+ return;
+ }
+ // Sub-protocols
+ List<String> subProtocols = getTokensFromHeader(req,
+ Constants.WS_PROTOCOL_HEADER_NAME);
+ String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
+ sec.getSubprotocols(), subProtocols);
+
+ // Extensions
+ // Should normally only be one header but handle the case of multiple
+ // headers
+ List<Extension> extensionsRequested = new ArrayList<>();
+ Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
+ while (extHeaders.hasMoreElements()) {
+ Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
+ }
+
+ // Negotiation phase 1. By default this simply filters out the
+ // extensions that the server does not support but applications could
+ // use a custom configurator to do more than this.
+ List<Extension> installedExtensions = null;
+ if (sec.getExtensions().size() == 0) {
+ installedExtensions = Constants.INSTALLED_EXTENSIONS;
+ } else {
+ installedExtensions = new ArrayList<>();
+ installedExtensions.addAll(sec.getExtensions());
+ installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
+ }
+ List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
+ installedExtensions, extensionsRequested);
+
+ // Negotiation phase 2. Create the Transformations that will be applied
+ // to this connection. Note than an extension may be dropped at this
+ // point if the client has requested a configuration that the server is
+ // unable to support.
+ List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
+
+ List<Extension> negotiatedExtensionsPhase2;
+ if (transformations.isEmpty()) {
+ negotiatedExtensionsPhase2 = Collections.emptyList();
+ } else {
+ negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
+ for (Transformation t : transformations) {
+ negotiatedExtensionsPhase2.add(t.getExtensionResponse());
+ }
+ }
+
+ WsHttpUpgradeHandler wsHandler =
+ req.upgrade(WsHttpUpgradeHandler.class);
+
+ WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
+ WsHandshakeResponse wsResponse = new WsHandshakeResponse();
+ WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
+ new WsPerSessionServerEndpointConfig(sec);
+ sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
+ wsRequest, wsResponse);
+ //wsRequest.finished();
+
+ // Add any additional headers
+ for (Entry<String,List<String>> entry :
+ wsResponse.getHeaders().entrySet()) {
+ for (String headerValue: entry.getValue()) {
+ resp.addHeader(entry.getKey(), headerValue);
+ }
+ }
+
+ Endpoint ep;
+ try {
+ Class<?> clazz = sec.getEndpointClass();
+ if (Endpoint.class.isAssignableFrom(clazz)) {
+ ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
+ clazz);
+ } else {
+ ep = new PojoEndpointServer();
+ // Need to make path params available to POJO
+ perSessionServerEndpointConfig.getUserProperties().put(
+ nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
+ }
+ } catch (InstantiationException e) {
+ throw new ServletException(e);
+ }
+
+ wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
+ negotiatedExtensionsPhase2, subProtocol, null, pathParams,
+ req.isSecure());
+
+ wsHandler.init(null);
+ }
+
+
+ private static List<Transformation> createTransformations(
+ List<Extension> negotiatedExtensions) {
+
+ TransformationFactory factory = TransformationFactory.getInstance();
+
+ LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
+ new LinkedHashMap<>();
+
+ // Result will likely be smaller than this
+ List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
+
+ for (Extension extension : negotiatedExtensions) {
+ List<List<Extension.Parameter>> preferences =
+ extensionPreferences.get(extension.getName());
+
+ if (preferences == null) {
+ preferences = new ArrayList<>();
+ extensionPreferences.put(extension.getName(), preferences);
+ }
+
+ preferences.add(extension.getParameters());
+ }
+
+ for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
+ extensionPreferences.entrySet()) {
+ Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
+ if (transformation != null) {
+ result.add(transformation);
+ }
+ }
+ return result;
+ }
+
+
+ private static void append(StringBuilder sb, Extension extension) {
+ if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
+ return;
+ }
+
+ sb.append(extension.getName());
+
+ for (Extension.Parameter p : extension.getParameters()) {
+ sb.append(';');
+ sb.append(p.getName());
+ if (p.getValue() != null) {
+ sb.append('=');
+ sb.append(p.getValue());
+ }
+ }
+ }
+
+
+ /*
+ * This only works for tokens. Quoted strings need more sophisticated
+ * parsing.
+ */
+ private static boolean headerContainsToken(HttpServletRequest req,
+ String headerName, String target) {
+ Enumeration<String> headers = req.getHeaders(headerName);
+ while (headers.hasMoreElements()) {
+ String header = headers.nextElement();
+ String[] tokens = header.split(",");
+ for (String token : tokens) {
+ if (target.equalsIgnoreCase(token.trim())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+
+ /*
+ * This only works for tokens. Quoted strings need more sophisticated
+ * parsing.
+ */
+ private static List<String> getTokensFromHeader(HttpServletRequest req,
+ String headerName) {
+ List<String> result = new ArrayList<>();
+ Enumeration<String> headers = req.getHeaders(headerName);
+ while (headers.hasMoreElements()) {
+ String header = headers.nextElement();
+ String[] tokens = header.split(",");
+ for (String token : tokens) {
+ result.add(token.trim());
+ }
+ }
+ return result;
+ }
+
+
+ private static String getWebSocketAccept(String key) {
+ byte[] digest = ConcurrentMessageDigest.digestSHA1(
+ key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
+ return Base64.encodeBase64String(digest);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/UriTemplate.java b/src/java/nginx/unit/websocket/server/UriTemplate.java
new file mode 100644
index 00000000..7877fac9
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/UriTemplate.java
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import javax.websocket.DeploymentException;
+
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Extracts path parameters from URIs used to create web socket connections
+ * using the URI template defined for the associated Endpoint.
+ */
+public class UriTemplate {
+
+ private static final StringManager sm = StringManager.getManager(UriTemplate.class);
+
+ private final String normalized;
+ private final List<Segment> segments = new ArrayList<>();
+ private final boolean hasParameters;
+
+
+ public UriTemplate(String path) throws DeploymentException {
+
+ if (path == null || path.length() ==0 || !path.startsWith("/")) {
+ throw new DeploymentException(
+ sm.getString("uriTemplate.invalidPath", path));
+ }
+
+ StringBuilder normalized = new StringBuilder(path.length());
+ Set<String> paramNames = new HashSet<>();
+
+ // Include empty segments.
+ String[] segments = path.split("/", -1);
+ int paramCount = 0;
+ int segmentCount = 0;
+
+ for (int i = 0; i < segments.length; i++) {
+ String segment = segments[i];
+ if (segment.length() == 0) {
+ if (i == 0 || (i == segments.length - 1 && paramCount == 0)) {
+ // Ignore the first empty segment as the path must always
+ // start with '/'
+ // Ending with a '/' is also OK for instances used for
+ // matches but not for parameterised templates.
+ continue;
+ } else {
+ // As per EG discussion, all other empty segments are
+ // invalid
+ throw new IllegalArgumentException(sm.getString(
+ "uriTemplate.emptySegment", path));
+ }
+ }
+ normalized.append('/');
+ int index = -1;
+ if (segment.startsWith("{") && segment.endsWith("}")) {
+ index = segmentCount;
+ segment = segment.substring(1, segment.length() - 1);
+ normalized.append('{');
+ normalized.append(paramCount++);
+ normalized.append('}');
+ if (!paramNames.add(segment)) {
+ throw new IllegalArgumentException(sm.getString(
+ "uriTemplate.duplicateParameter", segment));
+ }
+ } else {
+ if (segment.contains("{") || segment.contains("}")) {
+ throw new IllegalArgumentException(sm.getString(
+ "uriTemplate.invalidSegment", segment, path));
+ }
+ normalized.append(segment);
+ }
+ this.segments.add(new Segment(index, segment));
+ segmentCount++;
+ }
+
+ this.normalized = normalized.toString();
+ this.hasParameters = paramCount > 0;
+ }
+
+
+ public Map<String,String> match(UriTemplate candidate) {
+
+ Map<String,String> result = new HashMap<>();
+
+ // Should not happen but for safety
+ if (candidate.getSegmentCount() != getSegmentCount()) {
+ return null;
+ }
+
+ Iterator<Segment> candidateSegments =
+ candidate.getSegments().iterator();
+ Iterator<Segment> targetSegments = segments.iterator();
+
+ while (candidateSegments.hasNext()) {
+ Segment candidateSegment = candidateSegments.next();
+ Segment targetSegment = targetSegments.next();
+
+ if (targetSegment.getParameterIndex() == -1) {
+ // Not a parameter - values must match
+ if (!targetSegment.getValue().equals(
+ candidateSegment.getValue())) {
+ // Not a match. Stop here
+ return null;
+ }
+ } else {
+ // Parameter
+ result.put(targetSegment.getValue(),
+ candidateSegment.getValue());
+ }
+ }
+
+ return result;
+ }
+
+
+ public boolean hasParameters() {
+ return hasParameters;
+ }
+
+
+ public int getSegmentCount() {
+ return segments.size();
+ }
+
+
+ public String getNormalizedPath() {
+ return normalized;
+ }
+
+
+ private List<Segment> getSegments() {
+ return segments;
+ }
+
+
+ private static class Segment {
+ private final int parameterIndex;
+ private final String value;
+
+ public Segment(int parameterIndex, String value) {
+ this.parameterIndex = parameterIndex;
+ this.value = value;
+ }
+
+
+ public int getParameterIndex() {
+ return parameterIndex;
+ }
+
+
+ public String getValue() {
+ return value;
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsContextListener.java b/src/java/nginx/unit/websocket/server/WsContextListener.java
new file mode 100644
index 00000000..07137856
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsContextListener.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import javax.servlet.ServletContext;
+import javax.servlet.ServletContextEvent;
+import javax.servlet.ServletContextListener;
+
+/**
+ * In normal usage, this {@link ServletContextListener} does not need to be
+ * explicitly configured as the {@link WsSci} performs all the necessary
+ * bootstrap and installs this listener in the {@link ServletContext}. If the
+ * {@link WsSci} is disabled, this listener must be added manually to every
+ * {@link ServletContext} that uses WebSocket to bootstrap the
+ * {@link WsServerContainer} correctly.
+ */
+public class WsContextListener implements ServletContextListener {
+
+ @Override
+ public void contextInitialized(ServletContextEvent sce) {
+ ServletContext sc = sce.getServletContext();
+ // Don't trigger WebSocket initialization if a WebSocket Server
+ // Container is already present
+ if (sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE) == null) {
+ WsSci.init(sce.getServletContext(), false);
+ }
+ }
+
+ @Override
+ public void contextDestroyed(ServletContextEvent sce) {
+ ServletContext sc = sce.getServletContext();
+ Object obj = sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
+ if (obj instanceof WsServerContainer) {
+ ((WsServerContainer) obj).destroy();
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsFilter.java b/src/java/nginx/unit/websocket/server/WsFilter.java
new file mode 100644
index 00000000..abea71fc
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsFilter.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.IOException;
+
+import javax.servlet.FilterChain;
+import javax.servlet.GenericFilter;
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+
+/**
+ * Handles the initial HTTP connection for WebSocket connections.
+ */
+public class WsFilter extends GenericFilter {
+
+ private static final long serialVersionUID = 1L;
+
+ private transient WsServerContainer sc;
+
+
+ @Override
+ public void init() throws ServletException {
+ sc = (WsServerContainer) getServletContext().getAttribute(
+ Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE);
+ }
+
+
+ @Override
+ public void doFilter(ServletRequest request, ServletResponse response,
+ FilterChain chain) throws IOException, ServletException {
+
+ // This filter only needs to handle WebSocket upgrade requests
+ if (!sc.areEndpointsRegistered() ||
+ !UpgradeUtil.isWebSocketUpgradeRequest(request, response)) {
+ chain.doFilter(request, response);
+ return;
+ }
+
+ // HTTP request with an upgrade header for WebSocket present
+ HttpServletRequest req = (HttpServletRequest) request;
+ HttpServletResponse resp = (HttpServletResponse) response;
+
+ // Check to see if this WebSocket implementation has a matching mapping
+ String path;
+ String pathInfo = req.getPathInfo();
+ if (pathInfo == null) {
+ path = req.getServletPath();
+ } else {
+ path = req.getServletPath() + pathInfo;
+ }
+ WsMappingResult mappingResult = sc.findMapping(path);
+
+ if (mappingResult == null) {
+ // No endpoint registered for the requested path. Let the
+ // application handle it (it might redirect or forward for example)
+ chain.doFilter(request, response);
+ return;
+ }
+
+ UpgradeUtil.doUpgrade(sc, req, resp, mappingResult.getConfig(),
+ mappingResult.getPathParams());
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java
new file mode 100644
index 00000000..fa774302
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.net.URI;
+import java.net.URISyntaxException;
+import java.security.Principal;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import javax.servlet.http.HttpServletRequest;
+import javax.websocket.server.HandshakeRequest;
+
+import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap;
+import org.apache.tomcat.util.res.StringManager;
+
+/**
+ * Represents the request that this session was opened under.
+ */
+public class WsHandshakeRequest implements HandshakeRequest {
+
+ private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class);
+
+ private final URI requestUri;
+ private final Map<String,List<String>> parameterMap;
+ private final String queryString;
+ private final Principal userPrincipal;
+ private final Map<String,List<String>> headers;
+ private final Object httpSession;
+
+ private volatile HttpServletRequest request;
+
+
+ public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) {
+
+ this.request = request;
+
+ queryString = request.getQueryString();
+ userPrincipal = request.getUserPrincipal();
+ httpSession = request.getSession(false);
+ requestUri = buildRequestUri(request);
+
+ // ParameterMap
+ Map<String,String[]> originalParameters = request.getParameterMap();
+ Map<String,List<String>> newParameters =
+ new HashMap<>(originalParameters.size());
+ for (Entry<String,String[]> entry : originalParameters.entrySet()) {
+ newParameters.put(entry.getKey(),
+ Collections.unmodifiableList(
+ Arrays.asList(entry.getValue())));
+ }
+ for (Entry<String,String> entry : pathParams.entrySet()) {
+ newParameters.put(entry.getKey(),
+ Collections.unmodifiableList(
+ Collections.singletonList(entry.getValue())));
+ }
+ parameterMap = Collections.unmodifiableMap(newParameters);
+
+ // Headers
+ Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>();
+
+ Enumeration<String> headerNames = request.getHeaderNames();
+ while (headerNames.hasMoreElements()) {
+ String headerName = headerNames.nextElement();
+
+ newHeaders.put(headerName, Collections.unmodifiableList(
+ Collections.list(request.getHeaders(headerName))));
+ }
+
+ headers = Collections.unmodifiableMap(newHeaders);
+ }
+
+ @Override
+ public URI getRequestURI() {
+ return requestUri;
+ }
+
+ @Override
+ public Map<String,List<String>> getParameterMap() {
+ return parameterMap;
+ }
+
+ @Override
+ public String getQueryString() {
+ return queryString;
+ }
+
+ @Override
+ public Principal getUserPrincipal() {
+ return userPrincipal;
+ }
+
+ @Override
+ public Map<String,List<String>> getHeaders() {
+ return headers;
+ }
+
+ @Override
+ public boolean isUserInRole(String role) {
+ if (request == null) {
+ throw new IllegalStateException();
+ }
+
+ return request.isUserInRole(role);
+ }
+
+ @Override
+ public Object getHttpSession() {
+ return httpSession;
+ }
+
+ /**
+ * Called when the HandshakeRequest is no longer required. Since an instance
+ * of this class retains a reference to the current HttpServletRequest that
+ * reference needs to be cleared as the HttpServletRequest may be reused.
+ *
+ * There is no reason for instances of this class to be accessed once the
+ * handshake has been completed.
+ */
+ void finished() {
+ request = null;
+ }
+
+
+ /*
+ * See RequestUtil.getRequestURL()
+ */
+ private static URI buildRequestUri(HttpServletRequest req) {
+
+ StringBuffer uri = new StringBuffer();
+ String scheme = req.getScheme();
+ int port = req.getServerPort();
+ if (port < 0) {
+ // Work around java.net.URL bug
+ port = 80;
+ }
+
+ if ("http".equals(scheme)) {
+ uri.append("ws");
+ } else if ("https".equals(scheme)) {
+ uri.append("wss");
+ } else {
+ // Should never happen
+ throw new IllegalArgumentException(
+ sm.getString("wsHandshakeRequest.unknownScheme", scheme));
+ }
+
+ uri.append("://");
+ uri.append(req.getServerName());
+
+ if ((scheme.equals("http") && (port != 80))
+ || (scheme.equals("https") && (port != 443))) {
+ uri.append(':');
+ uri.append(port);
+ }
+
+ uri.append(req.getRequestURI());
+
+ if (req.getQueryString() != null) {
+ uri.append("?");
+ uri.append(req.getQueryString());
+ }
+
+ try {
+ return new URI(uri.toString());
+ } catch (URISyntaxException e) {
+ // Should never happen
+ throw new IllegalArgumentException(
+ sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e);
+ }
+ }
+
+ public Object getAttribute(String name)
+ {
+ return request != null ? request.getAttribute(name) : null;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java
new file mode 100644
index 00000000..cc39ab73
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+
+import javax.servlet.http.HttpSession;
+import javax.servlet.http.HttpUpgradeHandler;
+import javax.servlet.http.WebConnection;
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.Endpoint;
+import javax.websocket.EndpointConfig;
+import javax.websocket.Extension;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.res.StringManager;
+
+import nginx.unit.websocket.Transformation;
+import nginx.unit.websocket.WsIOException;
+import nginx.unit.websocket.WsSession;
+
+import nginx.unit.Request;
+
+/**
+ * Servlet 3.1 HTTP upgrade handler for WebSocket connections.
+ */
+public class WsHttpUpgradeHandler implements HttpUpgradeHandler {
+
+ private final Log log = LogFactory.getLog(WsHttpUpgradeHandler.class); // must not be static
+ private static final StringManager sm = StringManager.getManager(WsHttpUpgradeHandler.class);
+
+ private final ClassLoader applicationClassLoader;
+
+ private Endpoint ep;
+ private EndpointConfig endpointConfig;
+ private WsServerContainer webSocketContainer;
+ private WsHandshakeRequest handshakeRequest;
+ private List<Extension> negotiatedExtensions;
+ private String subProtocol;
+ private Transformation transformation;
+ private Map<String,String> pathParameters;
+ private boolean secure;
+ private WebConnection connection;
+ private WsRemoteEndpointImplServer wsRemoteEndpointServer;
+ private WsSession wsSession;
+
+
+ public WsHttpUpgradeHandler() {
+ applicationClassLoader = Thread.currentThread().getContextClassLoader();
+ }
+
+ public void preInit(Endpoint ep, EndpointConfig endpointConfig,
+ WsServerContainer wsc, WsHandshakeRequest handshakeRequest,
+ List<Extension> negotiatedExtensionsPhase2, String subProtocol,
+ Transformation transformation, Map<String,String> pathParameters,
+ boolean secure) {
+ this.ep = ep;
+ this.endpointConfig = endpointConfig;
+ this.webSocketContainer = wsc;
+ this.handshakeRequest = handshakeRequest;
+ this.negotiatedExtensions = negotiatedExtensionsPhase2;
+ this.subProtocol = subProtocol;
+ this.transformation = transformation;
+ this.pathParameters = pathParameters;
+ this.secure = secure;
+ }
+
+
+ @Override
+ public void init(WebConnection connection) {
+ if (ep == null) {
+ throw new IllegalStateException(
+ sm.getString("wsHttpUpgradeHandler.noPreInit"));
+ }
+
+ String httpSessionId = null;
+ Object session = handshakeRequest.getHttpSession();
+ if (session != null ) {
+ httpSessionId = ((HttpSession) session).getId();
+ }
+
+ nginx.unit.Context.trace("UpgradeHandler.init(" + connection + ")");
+
+/*
+ // Need to call onOpen using the web application's class loader
+ // Create the frame using the application's class loader so it can pick
+ // up application specific config from the ServerContainerImpl
+ Thread t = Thread.currentThread();
+ ClassLoader cl = t.getContextClassLoader();
+ t.setContextClassLoader(applicationClassLoader);
+*/
+ try {
+ Request r = (Request) handshakeRequest.getAttribute(Request.BARE);
+
+ wsRemoteEndpointServer = new WsRemoteEndpointImplServer(webSocketContainer);
+ wsSession = new WsSession(ep, wsRemoteEndpointServer,
+ webSocketContainer, handshakeRequest.getRequestURI(),
+ handshakeRequest.getParameterMap(),
+ handshakeRequest.getQueryString(),
+ handshakeRequest.getUserPrincipal(), httpSessionId,
+ negotiatedExtensions, subProtocol, pathParameters, secure,
+ endpointConfig, r);
+
+ ep.onOpen(wsSession, endpointConfig);
+ webSocketContainer.registerSession(ep, wsSession);
+ } catch (DeploymentException e) {
+ throw new IllegalArgumentException(e);
+/*
+ } finally {
+ t.setContextClassLoader(cl);
+*/
+ }
+ }
+
+
+
+ @Override
+ public void destroy() {
+ if (connection != null) {
+ try {
+ connection.close();
+ } catch (Exception e) {
+ log.error(sm.getString("wsHttpUpgradeHandler.destroyFailed"), e);
+ }
+ }
+ }
+
+
+ private void onError(Throwable throwable) {
+ // Need to call onError using the web application's class loader
+ Thread t = Thread.currentThread();
+ ClassLoader cl = t.getContextClassLoader();
+ t.setContextClassLoader(applicationClassLoader);
+ try {
+ ep.onError(wsSession, throwable);
+ } finally {
+ t.setContextClassLoader(cl);
+ }
+ }
+
+
+ private void close(CloseReason cr) {
+ /*
+ * Any call to this method is a result of a problem reading from the
+ * client. At this point that state of the connection is unknown.
+ * Attempt to send a close frame to the client and then close the socket
+ * immediately. There is no point in waiting for a close frame from the
+ * client because there is no guarantee that we can recover from
+ * whatever messed up state the client put the connection into.
+ */
+ wsSession.onClose(cr);
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsMappingResult.java b/src/java/nginx/unit/websocket/server/WsMappingResult.java
new file mode 100644
index 00000000..a7a4c022
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsMappingResult.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.util.Map;
+
+import javax.websocket.server.ServerEndpointConfig;
+
+class WsMappingResult {
+
+ private final ServerEndpointConfig config;
+ private final Map<String,String> pathParams;
+
+
+ WsMappingResult(ServerEndpointConfig config,
+ Map<String,String> pathParams) {
+ this.config = config;
+ this.pathParams = pathParams;
+ }
+
+
+ ServerEndpointConfig getConfig() {
+ return config;
+ }
+
+
+ Map<String,String> getPathParams() {
+ return pathParams;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java
new file mode 100644
index 00000000..2be050cb
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import javax.websocket.Decoder;
+import javax.websocket.Encoder;
+import javax.websocket.Extension;
+import javax.websocket.server.ServerEndpointConfig;
+
+/**
+ * Wraps the provided {@link ServerEndpointConfig} and provides a per session
+ * view - the difference being that the map returned by {@link
+ * #getUserProperties()} is unique to this instance rather than shared with the
+ * wrapped {@link ServerEndpointConfig}.
+ */
+class WsPerSessionServerEndpointConfig implements ServerEndpointConfig {
+
+ private final ServerEndpointConfig perEndpointConfig;
+ private final Map<String,Object> perSessionUserProperties =
+ new ConcurrentHashMap<>();
+
+ WsPerSessionServerEndpointConfig(ServerEndpointConfig perEndpointConfig) {
+ this.perEndpointConfig = perEndpointConfig;
+ perSessionUserProperties.putAll(perEndpointConfig.getUserProperties());
+ }
+
+ @Override
+ public List<Class<? extends Encoder>> getEncoders() {
+ return perEndpointConfig.getEncoders();
+ }
+
+ @Override
+ public List<Class<? extends Decoder>> getDecoders() {
+ return perEndpointConfig.getDecoders();
+ }
+
+ @Override
+ public Map<String,Object> getUserProperties() {
+ return perSessionUserProperties;
+ }
+
+ @Override
+ public Class<?> getEndpointClass() {
+ return perEndpointConfig.getEndpointClass();
+ }
+
+ @Override
+ public String getPath() {
+ return perEndpointConfig.getPath();
+ }
+
+ @Override
+ public List<String> getSubprotocols() {
+ return perEndpointConfig.getSubprotocols();
+ }
+
+ @Override
+ public List<Extension> getExtensions() {
+ return perEndpointConfig.getExtensions();
+ }
+
+ @Override
+ public Configurator getConfigurator() {
+ return perEndpointConfig.getConfigurator();
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java
new file mode 100644
index 00000000..6d10a3be
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.channels.CompletionHandler;
+import java.nio.channels.InterruptedByTimeoutException;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.Transformation;
+import nginx.unit.websocket.WsRemoteEndpointImplBase;
+
+/**
+ * This is the server side {@link javax.websocket.RemoteEndpoint} implementation
+ * - i.e. what the server uses to send data to the client.
+ */
+public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
+
+ private static final StringManager sm =
+ StringManager.getManager(WsRemoteEndpointImplServer.class);
+ private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static
+
+ private volatile SendHandler handler = null;
+ private volatile ByteBuffer[] buffers = null;
+
+ private volatile long timeoutExpiry = -1;
+ private volatile boolean close;
+
+ public WsRemoteEndpointImplServer(
+ WsServerContainer serverContainer) {
+ }
+
+
+ @Override
+ protected final boolean isMasked() {
+ return false;
+ }
+
+ @Override
+ protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
+ ByteBuffer... buffers) {
+ }
+
+ @Override
+ protected void doClose() {
+ if (handler != null) {
+ // close() can be triggered by a wide range of scenarios. It is far
+ // simpler just to always use a dispatch than it is to try and track
+ // whether or not this method was called by the same thread that
+ // triggered the write
+ clearHandler(new EOFException(), true);
+ }
+ }
+
+
+ protected long getTimeoutExpiry() {
+ return timeoutExpiry;
+ }
+
+
+ /*
+ * Currently this is only called from the background thread so we could just
+ * call clearHandler() with useDispatch == false but the method parameter
+ * was added in case other callers started to use this method to make sure
+ * that those callers think through what the correct value of useDispatch is
+ * for them.
+ */
+ protected void onTimeout(boolean useDispatch) {
+ if (handler != null) {
+ clearHandler(new SocketTimeoutException(), useDispatch);
+ }
+ close();
+ }
+
+
+ @Override
+ protected void setTransformation(Transformation transformation) {
+ // Overridden purely so it is visible to other classes in this package
+ super.setTransformation(transformation);
+ }
+
+
+ /**
+ *
+ * @param t The throwable associated with any error that
+ * occurred
+ * @param useDispatch Should {@link SendHandler#onResult(SendResult)} be
+ * called from a new thread, keeping in mind the
+ * requirements of
+ * {@link javax.websocket.RemoteEndpoint.Async}
+ */
+ private void clearHandler(Throwable t, boolean useDispatch) {
+ // Setting the result marks this (partial) message as
+ // complete which means the next one may be sent which
+ // could update the value of the handler. Therefore, keep a
+ // local copy before signalling the end of the (partial)
+ // message.
+ SendHandler sh = handler;
+ handler = null;
+ buffers = null;
+ if (sh != null) {
+ if (useDispatch) {
+ OnResultRunnable r = new OnResultRunnable(sh, t);
+ } else {
+ if (t == null) {
+ sh.onResult(new SendResult());
+ } else {
+ sh.onResult(new SendResult(t));
+ }
+ }
+ }
+ }
+
+
+ private static class OnResultRunnable implements Runnable {
+
+ private final SendHandler sh;
+ private final Throwable t;
+
+ private OnResultRunnable(SendHandler sh, Throwable t) {
+ this.sh = sh;
+ this.t = t;
+ }
+
+ @Override
+ public void run() {
+ if (t == null) {
+ sh.onResult(new SendResult());
+ } else {
+ sh.onResult(new SendResult(t));
+ }
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsSci.java b/src/java/nginx/unit/websocket/server/WsSci.java
new file mode 100644
index 00000000..cdecce27
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsSci.java
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.lang.reflect.Modifier;
+import java.util.HashSet;
+import java.util.Set;
+
+import javax.servlet.ServletContainerInitializer;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletException;
+import javax.servlet.annotation.HandlesTypes;
+import javax.websocket.ContainerProvider;
+import javax.websocket.DeploymentException;
+import javax.websocket.Endpoint;
+import javax.websocket.server.ServerApplicationConfig;
+import javax.websocket.server.ServerEndpoint;
+import javax.websocket.server.ServerEndpointConfig;
+
+/**
+ * Registers an interest in any class that is annotated with
+ * {@link ServerEndpoint} so that Endpoint can be published via the WebSocket
+ * server.
+ */
+@HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class,
+ Endpoint.class})
+public class WsSci implements ServletContainerInitializer {
+
+ @Override
+ public void onStartup(Set<Class<?>> clazzes, ServletContext ctx)
+ throws ServletException {
+
+ WsServerContainer sc = init(ctx, true);
+
+ if (clazzes == null || clazzes.size() == 0) {
+ return;
+ }
+
+ // Group the discovered classes by type
+ Set<ServerApplicationConfig> serverApplicationConfigs = new HashSet<>();
+ Set<Class<? extends Endpoint>> scannedEndpointClazzes = new HashSet<>();
+ Set<Class<?>> scannedPojoEndpoints = new HashSet<>();
+
+ try {
+ // wsPackage is "javax.websocket."
+ String wsPackage = ContainerProvider.class.getName();
+ wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1);
+ for (Class<?> clazz : clazzes) {
+ int modifiers = clazz.getModifiers();
+ if (!Modifier.isPublic(modifiers) ||
+ Modifier.isAbstract(modifiers)) {
+ // Non-public or abstract - skip it.
+ continue;
+ }
+ // Protect against scanning the WebSocket API JARs
+ if (clazz.getName().startsWith(wsPackage)) {
+ continue;
+ }
+ if (ServerApplicationConfig.class.isAssignableFrom(clazz)) {
+ serverApplicationConfigs.add(
+ (ServerApplicationConfig) clazz.getConstructor().newInstance());
+ }
+ if (Endpoint.class.isAssignableFrom(clazz)) {
+ @SuppressWarnings("unchecked")
+ Class<? extends Endpoint> endpoint =
+ (Class<? extends Endpoint>) clazz;
+ scannedEndpointClazzes.add(endpoint);
+ }
+ if (clazz.isAnnotationPresent(ServerEndpoint.class)) {
+ scannedPojoEndpoints.add(clazz);
+ }
+ }
+ } catch (ReflectiveOperationException e) {
+ throw new ServletException(e);
+ }
+
+ // Filter the results
+ Set<ServerEndpointConfig> filteredEndpointConfigs = new HashSet<>();
+ Set<Class<?>> filteredPojoEndpoints = new HashSet<>();
+
+ if (serverApplicationConfigs.isEmpty()) {
+ filteredPojoEndpoints.addAll(scannedPojoEndpoints);
+ } else {
+ for (ServerApplicationConfig config : serverApplicationConfigs) {
+ Set<ServerEndpointConfig> configFilteredEndpoints =
+ config.getEndpointConfigs(scannedEndpointClazzes);
+ if (configFilteredEndpoints != null) {
+ filteredEndpointConfigs.addAll(configFilteredEndpoints);
+ }
+ Set<Class<?>> configFilteredPojos =
+ config.getAnnotatedEndpointClasses(
+ scannedPojoEndpoints);
+ if (configFilteredPojos != null) {
+ filteredPojoEndpoints.addAll(configFilteredPojos);
+ }
+ }
+ }
+
+ try {
+ // Deploy endpoints
+ for (ServerEndpointConfig config : filteredEndpointConfigs) {
+ sc.addEndpoint(config);
+ }
+ // Deploy POJOs
+ for (Class<?> clazz : filteredPojoEndpoints) {
+ sc.addEndpoint(clazz);
+ }
+ } catch (DeploymentException e) {
+ throw new ServletException(e);
+ }
+ }
+
+
+ static WsServerContainer init(ServletContext servletContext,
+ boolean initBySciMechanism) {
+
+ WsServerContainer sc = new WsServerContainer(servletContext);
+
+ servletContext.setAttribute(
+ Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc);
+
+ servletContext.addListener(new WsSessionListener(sc));
+ // Can't register the ContextListener again if the ContextListener is
+ // calling this method
+ if (initBySciMechanism) {
+ servletContext.addListener(new WsContextListener());
+ }
+
+ return sc;
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsServerContainer.java b/src/java/nginx/unit/websocket/server/WsServerContainer.java
new file mode 100644
index 00000000..069fc54f
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsServerContainer.java
@@ -0,0 +1,470 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.EnumSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.SortedSet;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentHashMap;
+
+import javax.servlet.DispatcherType;
+import javax.servlet.FilterRegistration;
+import javax.servlet.ServletContext;
+import javax.servlet.ServletException;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.Encoder;
+import javax.websocket.Endpoint;
+import javax.websocket.server.ServerContainer;
+import javax.websocket.server.ServerEndpoint;
+import javax.websocket.server.ServerEndpointConfig;
+import javax.websocket.server.ServerEndpointConfig.Configurator;
+
+import org.apache.tomcat.InstanceManager;
+import org.apache.tomcat.util.res.StringManager;
+import nginx.unit.websocket.WsSession;
+import nginx.unit.websocket.WsWebSocketContainer;
+import nginx.unit.websocket.pojo.PojoMethodMapping;
+
+/**
+ * Provides a per class loader (i.e. per web application) instance of a
+ * ServerContainer. Web application wide defaults may be configured by setting
+ * the following servlet context initialisation parameters to the desired
+ * values.
+ * <ul>
+ * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
+ * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li>
+ * </ul>
+ */
+public class WsServerContainer extends WsWebSocketContainer
+ implements ServerContainer {
+
+ private static final StringManager sm = StringManager.getManager(WsServerContainer.class);
+
+ private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED =
+ new CloseReason(CloseCodes.VIOLATED_POLICY,
+ "This connection was established under an authenticated " +
+ "HTTP session that has ended.");
+
+ private final ServletContext servletContext;
+ private final Map<String,ServerEndpointConfig> configExactMatchMap =
+ new ConcurrentHashMap<>();
+ private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap =
+ new ConcurrentHashMap<>();
+ private volatile boolean enforceNoAddAfterHandshake =
+ nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE;
+ private volatile boolean addAllowed = true;
+ private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>();
+ private volatile boolean endpointsRegistered = false;
+
+ WsServerContainer(ServletContext servletContext) {
+
+ this.servletContext = servletContext;
+ setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName()));
+
+ // Configure servlet context wide defaults
+ String value = servletContext.getInitParameter(
+ Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
+ if (value != null) {
+ setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value));
+ }
+
+ value = servletContext.getInitParameter(
+ Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM);
+ if (value != null) {
+ setDefaultMaxTextMessageBufferSize(Integer.parseInt(value));
+ }
+
+ value = servletContext.getInitParameter(
+ Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM);
+ if (value != null) {
+ setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value));
+ }
+
+ FilterRegistration.Dynamic fr = servletContext.addFilter(
+ "Tomcat WebSocket (JSR356) Filter", new WsFilter());
+ fr.setAsyncSupported(true);
+
+ EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST,
+ DispatcherType.FORWARD);
+
+ fr.addMappingForUrlPatterns(types, true, "/*");
+ }
+
+
+ /**
+ * Published the provided endpoint implementation at the specified path with
+ * the specified configuration. {@link #WsServerContainer(ServletContext)}
+ * must be called before calling this method.
+ *
+ * @param sec The configuration to use when creating endpoint instances
+ * @throws DeploymentException if the endpoint cannot be published as
+ * requested
+ */
+ @Override
+ public void addEndpoint(ServerEndpointConfig sec)
+ throws DeploymentException {
+
+ if (enforceNoAddAfterHandshake && !addAllowed) {
+ throw new DeploymentException(
+ sm.getString("serverContainer.addNotAllowed"));
+ }
+
+ if (servletContext == null) {
+ throw new DeploymentException(
+ sm.getString("serverContainer.servletContextMissing"));
+ }
+ String path = sec.getPath();
+
+ // Add method mapping to user properties
+ PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(),
+ sec.getDecoders(), path);
+ if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null
+ || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) {
+ sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY,
+ methodMapping);
+ }
+
+ UriTemplate uriTemplate = new UriTemplate(path);
+ if (uriTemplate.hasParameters()) {
+ Integer key = Integer.valueOf(uriTemplate.getSegmentCount());
+ SortedSet<TemplatePathMatch> templateMatches =
+ configTemplateMatchMap.get(key);
+ if (templateMatches == null) {
+ // Ensure that if concurrent threads execute this block they
+ // both end up using the same TreeSet instance
+ templateMatches = new TreeSet<>(
+ TemplatePathMatchComparator.getInstance());
+ configTemplateMatchMap.putIfAbsent(key, templateMatches);
+ templateMatches = configTemplateMatchMap.get(key);
+ }
+ if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) {
+ // Duplicate uriTemplate;
+ throw new DeploymentException(
+ sm.getString("serverContainer.duplicatePaths", path,
+ sec.getEndpointClass(),
+ sec.getEndpointClass()));
+ }
+ } else {
+ // Exact match
+ ServerEndpointConfig old = configExactMatchMap.put(path, sec);
+ if (old != null) {
+ // Duplicate path mappings
+ throw new DeploymentException(
+ sm.getString("serverContainer.duplicatePaths", path,
+ old.getEndpointClass(),
+ sec.getEndpointClass()));
+ }
+ }
+
+ endpointsRegistered = true;
+ }
+
+
+ /**
+ * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)}
+ * for publishing plain old java objects (POJOs) that have been annotated as
+ * WebSocket endpoints.
+ *
+ * @param pojo The annotated POJO
+ */
+ @Override
+ public void addEndpoint(Class<?> pojo) throws DeploymentException {
+
+ ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class);
+ if (annotation == null) {
+ throw new DeploymentException(
+ sm.getString("serverContainer.missingAnnotation",
+ pojo.getName()));
+ }
+ String path = annotation.value();
+
+ // Validate encoders
+ validateEncoders(annotation.encoders());
+
+ // ServerEndpointConfig
+ ServerEndpointConfig sec;
+ Class<? extends Configurator> configuratorClazz =
+ annotation.configurator();
+ Configurator configurator = null;
+ if (!configuratorClazz.equals(Configurator.class)) {
+ try {
+ configurator = annotation.configurator().getConstructor().newInstance();
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "serverContainer.configuratorFail",
+ annotation.configurator().getName(),
+ pojo.getClass().getName()), e);
+ }
+ }
+ if (configurator == null) {
+ configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator();
+ }
+ sec = ServerEndpointConfig.Builder.create(pojo, path).
+ decoders(Arrays.asList(annotation.decoders())).
+ encoders(Arrays.asList(annotation.encoders())).
+ subprotocols(Arrays.asList(annotation.subprotocols())).
+ configurator(configurator).
+ build();
+
+ addEndpoint(sec);
+ }
+
+
+ boolean areEndpointsRegistered() {
+ return endpointsRegistered;
+ }
+
+
+ /**
+ * Until the WebSocket specification provides such a mechanism, this Tomcat
+ * proprietary method is provided to enable applications to programmatically
+ * determine whether or not to upgrade an individual request to WebSocket.
+ * <p>
+ * Note: This method is not used by Tomcat but is used directly by
+ * third-party code and must not be removed.
+ *
+ * @param request The request object to be upgraded
+ * @param response The response object to be populated with the result of
+ * the upgrade
+ * @param sec The server endpoint to use to process the upgrade request
+ * @param pathParams The path parameters associated with the upgrade request
+ *
+ * @throws ServletException If a configuration error prevents the upgrade
+ * from taking place
+ * @throws IOException If an I/O error occurs during the upgrade process
+ */
+ public void doUpgrade(HttpServletRequest request,
+ HttpServletResponse response, ServerEndpointConfig sec,
+ Map<String,String> pathParams)
+ throws ServletException, IOException {
+ UpgradeUtil.doUpgrade(this, request, response, sec, pathParams);
+ }
+
+
+ public WsMappingResult findMapping(String path) {
+
+ // Prevent registering additional endpoints once the first attempt has
+ // been made to use one
+ if (addAllowed) {
+ addAllowed = false;
+ }
+
+ // Check an exact match. Simple case as there are no templates.
+ ServerEndpointConfig sec = configExactMatchMap.get(path);
+ if (sec != null) {
+ return new WsMappingResult(sec, Collections.<String, String>emptyMap());
+ }
+
+ // No exact match. Need to look for template matches.
+ UriTemplate pathUriTemplate = null;
+ try {
+ pathUriTemplate = new UriTemplate(path);
+ } catch (DeploymentException e) {
+ // Path is not valid so can't be matched to a WebSocketEndpoint
+ return null;
+ }
+
+ // Number of segments has to match
+ Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount());
+ SortedSet<TemplatePathMatch> templateMatches =
+ configTemplateMatchMap.get(key);
+
+ if (templateMatches == null) {
+ // No templates with an equal number of segments so there will be
+ // no matches
+ return null;
+ }
+
+ // List is in alphabetical order of normalised templates.
+ // Correct match is the first one that matches.
+ Map<String,String> pathParams = null;
+ for (TemplatePathMatch templateMatch : templateMatches) {
+ pathParams = templateMatch.getUriTemplate().match(pathUriTemplate);
+ if (pathParams != null) {
+ sec = templateMatch.getConfig();
+ break;
+ }
+ }
+
+ if (sec == null) {
+ // No match
+ return null;
+ }
+
+ return new WsMappingResult(sec, pathParams);
+ }
+
+
+
+ public boolean isEnforceNoAddAfterHandshake() {
+ return enforceNoAddAfterHandshake;
+ }
+
+
+ public void setEnforceNoAddAfterHandshake(
+ boolean enforceNoAddAfterHandshake) {
+ this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * Overridden to make it visible to other classes in this package.
+ */
+ @Override
+ protected void registerSession(Endpoint endpoint, WsSession wsSession) {
+ super.registerSession(endpoint, wsSession);
+ if (wsSession.isOpen() &&
+ wsSession.getUserPrincipal() != null &&
+ wsSession.getHttpSessionId() != null) {
+ registerAuthenticatedSession(wsSession,
+ wsSession.getHttpSessionId());
+ }
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * Overridden to make it visible to other classes in this package.
+ */
+ @Override
+ protected void unregisterSession(Endpoint endpoint, WsSession wsSession) {
+ if (wsSession.getUserPrincipal() != null &&
+ wsSession.getHttpSessionId() != null) {
+ unregisterAuthenticatedSession(wsSession,
+ wsSession.getHttpSessionId());
+ }
+ super.unregisterSession(endpoint, wsSession);
+ }
+
+
+ private void registerAuthenticatedSession(WsSession wsSession,
+ String httpSessionId) {
+ Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
+ if (wsSessions == null) {
+ wsSessions = Collections.newSetFromMap(
+ new ConcurrentHashMap<WsSession,Boolean>());
+ authenticatedSessions.putIfAbsent(httpSessionId, wsSessions);
+ wsSessions = authenticatedSessions.get(httpSessionId);
+ }
+ wsSessions.add(wsSession);
+ }
+
+
+ private void unregisterAuthenticatedSession(WsSession wsSession,
+ String httpSessionId) {
+ Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId);
+ // wsSessions will be null if the HTTP session has ended
+ if (wsSessions != null) {
+ wsSessions.remove(wsSession);
+ }
+ }
+
+
+ public void closeAuthenticatedSession(String httpSessionId) {
+ Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId);
+
+ if (wsSessions != null && !wsSessions.isEmpty()) {
+ for (WsSession wsSession : wsSessions) {
+ try {
+ wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED);
+ } catch (IOException e) {
+ // Any IOExceptions during close will have been caught and the
+ // onError method called.
+ }
+ }
+ }
+ }
+
+
+ private static void validateEncoders(Class<? extends Encoder>[] encoders)
+ throws DeploymentException {
+
+ for (Class<? extends Encoder> encoder : encoders) {
+ // Need to instantiate decoder to ensure it is valid and that
+ // deployment can be failed if it is not
+ @SuppressWarnings("unused")
+ Encoder instance;
+ try {
+ encoder.getConstructor().newInstance();
+ } catch(ReflectiveOperationException e) {
+ throw new DeploymentException(sm.getString(
+ "serverContainer.encoderFail", encoder.getName()), e);
+ }
+ }
+ }
+
+
+ private static class TemplatePathMatch {
+ private final ServerEndpointConfig config;
+ private final UriTemplate uriTemplate;
+
+ public TemplatePathMatch(ServerEndpointConfig config,
+ UriTemplate uriTemplate) {
+ this.config = config;
+ this.uriTemplate = uriTemplate;
+ }
+
+
+ public ServerEndpointConfig getConfig() {
+ return config;
+ }
+
+
+ public UriTemplate getUriTemplate() {
+ return uriTemplate;
+ }
+ }
+
+
+ /**
+ * This Comparator implementation is thread-safe so only create a single
+ * instance.
+ */
+ private static class TemplatePathMatchComparator
+ implements Comparator<TemplatePathMatch> {
+
+ private static final TemplatePathMatchComparator INSTANCE =
+ new TemplatePathMatchComparator();
+
+ public static TemplatePathMatchComparator getInstance() {
+ return INSTANCE;
+ }
+
+ private TemplatePathMatchComparator() {
+ // Hide default constructor
+ }
+
+ @Override
+ public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) {
+ return tpm1.getUriTemplate().getNormalizedPath().compareTo(
+ tpm2.getUriTemplate().getNormalizedPath());
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsSessionListener.java b/src/java/nginx/unit/websocket/server/WsSessionListener.java
new file mode 100644
index 00000000..fc2bc9c5
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsSessionListener.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import javax.servlet.http.HttpSessionEvent;
+import javax.servlet.http.HttpSessionListener;
+
+public class WsSessionListener implements HttpSessionListener{
+
+ private final WsServerContainer wsServerContainer;
+
+
+ public WsSessionListener(WsServerContainer wsServerContainer) {
+ this.wsServerContainer = wsServerContainer;
+ }
+
+
+ @Override
+ public void sessionDestroyed(HttpSessionEvent se) {
+ wsServerContainer.closeAuthenticatedSession(se.getSession().getId());
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/WsWriteTimeout.java b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java
new file mode 100644
index 00000000..2dfc4ab2
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.util.Comparator;
+import java.util.Set;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import nginx.unit.websocket.BackgroundProcess;
+import nginx.unit.websocket.BackgroundProcessManager;
+
+/**
+ * Provides timeouts for asynchronous web socket writes. On the server side we
+ * only have access to {@link javax.servlet.ServletOutputStream} and
+ * {@link javax.servlet.ServletInputStream} so there is no way to set a timeout
+ * for writes to the client.
+ */
+public class WsWriteTimeout implements BackgroundProcess {
+
+ private final Set<WsRemoteEndpointImplServer> endpoints =
+ new ConcurrentSkipListSet<>(new EndpointComparator());
+ private final AtomicInteger count = new AtomicInteger(0);
+ private int backgroundProcessCount = 0;
+ private volatile int processPeriod = 1;
+
+ @Override
+ public void backgroundProcess() {
+ // This method gets called once a second.
+ backgroundProcessCount ++;
+
+ if (backgroundProcessCount >= processPeriod) {
+ backgroundProcessCount = 0;
+
+ long now = System.currentTimeMillis();
+ for (WsRemoteEndpointImplServer endpoint : endpoints) {
+ if (endpoint.getTimeoutExpiry() < now) {
+ // Background thread, not the thread that triggered the
+ // write so no need to use a dispatch
+ endpoint.onTimeout(false);
+ } else {
+ // Endpoints are ordered by timeout expiry so if this point
+ // is reached there is no need to check the remaining
+ // endpoints
+ break;
+ }
+ }
+ }
+ }
+
+
+ @Override
+ public void setProcessPeriod(int period) {
+ this.processPeriod = period;
+ }
+
+
+ /**
+ * {@inheritDoc}
+ *
+ * The default value is 1 which means asynchronous write timeouts are
+ * processed every 1 second.
+ */
+ @Override
+ public int getProcessPeriod() {
+ return processPeriod;
+ }
+
+
+ public void register(WsRemoteEndpointImplServer endpoint) {
+ boolean result = endpoints.add(endpoint);
+ if (result) {
+ int newCount = count.incrementAndGet();
+ if (newCount == 1) {
+ BackgroundProcessManager.getInstance().register(this);
+ }
+ }
+ }
+
+
+ public void unregister(WsRemoteEndpointImplServer endpoint) {
+ boolean result = endpoints.remove(endpoint);
+ if (result) {
+ int newCount = count.decrementAndGet();
+ if (newCount == 0) {
+ BackgroundProcessManager.getInstance().unregister(this);
+ }
+ }
+ }
+
+
+ /**
+ * Note: this comparator imposes orderings that are inconsistent with equals
+ */
+ private static class EndpointComparator implements
+ Comparator<WsRemoteEndpointImplServer> {
+
+ @Override
+ public int compare(WsRemoteEndpointImplServer o1,
+ WsRemoteEndpointImplServer o2) {
+
+ long t1 = o1.getTimeoutExpiry();
+ long t2 = o2.getTimeoutExpiry();
+
+ if (t1 < t2) {
+ return -1;
+ } else if (t1 == t2) {
+ return 0;
+ } else {
+ return 1;
+ }
+ }
+ }
+}
diff --git a/src/java/nginx/unit/websocket/server/package-info.java b/src/java/nginx/unit/websocket/server/package-info.java
new file mode 100644
index 00000000..87bc85a3
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+/**
+ * Server-side specific implementation classes. These are in a separate package
+ * to make packaging a pure client JAR simpler.
+ */
+package nginx.unit.websocket.server;