diff options
author | Konstantin Pavlov <thresh@nginx.com> | 2019-09-19 19:04:16 +0300 |
---|---|---|
committer | Konstantin Pavlov <thresh@nginx.com> | 2019-09-19 19:04:16 +0300 |
commit | deb26fa47a9ab1b358938134a8ced8bbc4a083e1 (patch) | |
tree | 0bedf8829f003fa4c0101e3421b7184acc1c8343 /src/java/nginx/unit/websocket | |
parent | fcb1f851d0b5d1774a6cb876288ea29cfef58618 (diff) | |
parent | db777d1e7f607d1b0f01dfb73ad0bac12987202b (diff) | |
download | unit-deb26fa47a9ab1b358938134a8ced8bbc4a083e1.tar.gz unit-deb26fa47a9ab1b358938134a8ced8bbc4a083e1.tar.bz2 |
Merged with the default branch.
Diffstat (limited to 'src/java/nginx/unit/websocket')
73 files changed, 12702 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java new file mode 100644 index 00000000..147112c1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.channels.AsynchronousChannelGroup; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.threads.ThreadPoolExecutor; + +/** + * This is a utility class that enables multiple {@link WsWebSocketContainer} + * instances to share a single {@link AsynchronousChannelGroup} while ensuring + * that the group is destroyed when no longer required. + */ +public class AsyncChannelGroupUtil { + + private static final StringManager sm = + StringManager.getManager(AsyncChannelGroupUtil.class); + + private static AsynchronousChannelGroup group = null; + private static int usageCount = 0; + private static final Object lock = new Object(); + + + private AsyncChannelGroupUtil() { + // Hide the default constructor + } + + + public static AsynchronousChannelGroup register() { + synchronized (lock) { + if (usageCount == 0) { + group = createAsynchronousChannelGroup(); + } + usageCount++; + return group; + } + } + + + public static void unregister() { + synchronized (lock) { + usageCount--; + if (usageCount == 0) { + group.shutdown(); + group = null; + } + } + } + + + private static AsynchronousChannelGroup createAsynchronousChannelGroup() { + // Need to do this with the right thread context class loader else the + // first web app to call this will trigger a leak + ClassLoader original = Thread.currentThread().getContextClassLoader(); + + try { + Thread.currentThread().setContextClassLoader( + AsyncIOThreadFactory.class.getClassLoader()); + + // These are the same settings as the default + // AsynchronousChannelGroup + int initialSize = Runtime.getRuntime().availableProcessors(); + ExecutorService executorService = new ThreadPoolExecutor( + 0, + Integer.MAX_VALUE, + Long.MAX_VALUE, TimeUnit.MILLISECONDS, + new SynchronousQueue<Runnable>(), + new AsyncIOThreadFactory()); + + try { + return AsynchronousChannelGroup.withCachedThreadPool( + executorService, initialSize); + } catch (IOException e) { + // No good reason for this to happen. + throw new IllegalStateException(sm.getString("asyncChannelGroup.createFail")); + } + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + + private static class AsyncIOThreadFactory implements ThreadFactory { + + static { + // Load NewThreadPrivilegedAction since newThread() will not be able + // to if called from an InnocuousThread. + // See https://bz.apache.org/bugzilla/show_bug.cgi?id=57490 + NewThreadPrivilegedAction.load(); + } + + + @Override + public Thread newThread(final Runnable r) { + // Create the new Thread within a doPrivileged block to ensure that + // the thread inherits the current ProtectionDomain which is + // essential to be able to use this with a Java Applet. See + // https://bz.apache.org/bugzilla/show_bug.cgi?id=57091 + return AccessController.doPrivileged(new NewThreadPrivilegedAction(r)); + } + + // Non-anonymous class so that AsyncIOThreadFactory can load it + // explicitly + private static class NewThreadPrivilegedAction implements PrivilegedAction<Thread> { + + private static AtomicInteger count = new AtomicInteger(0); + + private final Runnable r; + + public NewThreadPrivilegedAction(Runnable r) { + this.r = r; + } + + @Override + public Thread run() { + Thread t = new Thread(r); + t.setName("WebSocketClient-AsyncIO-" + count.incrementAndGet()); + t.setContextClassLoader(this.getClass().getClassLoader()); + t.setDaemon(true); + return t; + } + + private static void load() { + // NO-OP. Just provides a hook to enable the class to be loaded + } + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapper.java b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java new file mode 100644 index 00000000..060ae9cb --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLException; + +/** + * This is a wrapper for a {@link java.nio.channels.AsynchronousSocketChannel} + * that limits the methods available thereby simplifying the process of + * implementing SSL/TLS support since there are fewer methods to intercept. + */ +public interface AsyncChannelWrapper { + + Future<Integer> read(ByteBuffer dst); + + <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler); + + Future<Integer> write(ByteBuffer src); + + <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler); + + void close(); + + Future<Void> handshake() throws SSLException; +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java new file mode 100644 index 00000000..5b88bfe1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * Generally, just passes calls straight to the wrapped + * {@link AsynchronousSocketChannel}. In some cases exceptions may be swallowed + * to save them being swallowed by the calling code. + */ +public class AsyncChannelWrapperNonSecure implements AsyncChannelWrapper { + + private static final Future<Void> NOOP_FUTURE = new NoOpFuture(); + + private final AsynchronousSocketChannel socketChannel; + + public AsyncChannelWrapperNonSecure( + AsynchronousSocketChannel socketChannel) { + this.socketChannel = socketChannel; + } + + @Override + public Future<Integer> read(ByteBuffer dst) { + return socketChannel.read(dst); + } + + @Override + public <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler) { + socketChannel.read(dst, attachment, handler); + } + + @Override + public Future<Integer> write(ByteBuffer src) { + return socketChannel.write(src); + } + + @Override + public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler) { + socketChannel.write( + srcs, offset, length, timeout, unit, attachment, handler); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + // Ignore + } + } + + @Override + public Future<Void> handshake() { + return NOOP_FUTURE; + } + + + private static final class NoOpFuture implements Future<Void> { + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + return null; + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java new file mode 100644 index 00000000..21654487 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +/** + * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot + * more testing before it can be considered robust. + */ +public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { + + private final Log log = + LogFactory.getLog(AsyncChannelWrapperSecure.class); + private static final StringManager sm = + StringManager.getManager(AsyncChannelWrapperSecure.class); + + private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921); + private final AsynchronousSocketChannel socketChannel; + private final SSLEngine sslEngine; + private final ByteBuffer socketReadBuffer; + private final ByteBuffer socketWriteBuffer; + // One thread for read, one for write + private final ExecutorService executor = + Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); + private AtomicBoolean writing = new AtomicBoolean(false); + private AtomicBoolean reading = new AtomicBoolean(false); + + public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, + SSLEngine sslEngine) { + this.socketChannel = socketChannel; + this.sslEngine = sslEngine; + + int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); + socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); + socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize); + } + + @Override + public Future<Integer> read(ByteBuffer dst) { + WrapperFuture<Integer,Void> future = new WrapperFuture<>(); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + + return future; + } + + @Override + public <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler) { + + WrapperFuture<Integer,B> future = + new WrapperFuture<>(handler, attachment); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + } + + @Override + public Future<Integer> write(ByteBuffer src) { + + WrapperFuture<Long,Void> inner = new WrapperFuture<>(); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = + new WriteTask(new ByteBuffer[] {src}, 0, 1, inner); + + executor.execute(writeTask); + + Future<Integer> future = new LongToIntegerFuture(inner); + return future; + } + + @Override + public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler) { + + WrapperFuture<Long,B> future = + new WrapperFuture<>(handler, attachment); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = new WriteTask(srcs, offset, length, future); + + executor.execute(writeTask); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); + } + executor.shutdownNow(); + } + + @Override + public Future<Void> handshake() throws SSLException { + + WrapperFuture<Void,Void> wFuture = new WrapperFuture<>(); + + Thread t = new WebSocketSslHandshakeThread(wFuture); + t.start(); + + return wFuture; + } + + + private class WriteTask implements Runnable { + + private final ByteBuffer[] srcs; + private final int offset; + private final int length; + private final WrapperFuture<Long,?> future; + + public WriteTask(ByteBuffer[] srcs, int offset, int length, + WrapperFuture<Long,?> future) { + this.srcs = srcs; + this.future = future; + this.offset = offset; + this.length = length; + } + + @Override + public void run() { + long written = 0; + + try { + for (int i = offset; i < offset + length; i++) { + ByteBuffer src = srcs[i]; + while (src.hasRemaining()) { + socketWriteBuffer.clear(); + + // Encrypt the data + SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer); + written += r.bytesConsumed(); + Status s = r.getStatus(); + + if (s == Status.OK || s == Status.BUFFER_OVERFLOW) { + // Need to write out the bytes and may need to read from + // the source again to empty it + } else { + // Status.BUFFER_UNDERFLOW - only happens on unwrap + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusWrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + + socketWriteBuffer.flip(); + + // Do the write + int toWrite = r.bytesProduced(); + while (toWrite > 0) { + Future<Integer> f = + socketChannel.write(socketWriteBuffer); + Integer socketWrite = f.get(); + toWrite -= socketWrite.intValue(); + } + } + } + + + if (writing.compareAndSet(true, false)) { + future.complete(Long.valueOf(written)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateWrite"))); + } + } catch (Exception e) { + writing.set(false); + future.fail(e); + } + } + } + + + private class ReadTask implements Runnable { + + private final ByteBuffer dest; + private final WrapperFuture<Integer,?> future; + + public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) { + this.dest = dest; + this.future = future; + } + + @Override + public void run() { + int read = 0; + + boolean forceRead = false; + + try { + while (read == 0) { + socketReadBuffer.compact(); + + if (forceRead) { + forceRead = false; + Future<Integer> f = socketChannel.read(socketReadBuffer); + Integer socketRead = f.get(); + if (socketRead.intValue() == -1) { + throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof")); + } + } + + socketReadBuffer.flip(); + + if (socketReadBuffer.hasRemaining()) { + // Decrypt the data in the buffer + SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest); + read += r.bytesProduced(); + Status s = r.getStatus(); + + if (s == Status.OK) { + // Bytes available for reading and there may be + // sufficient data in the socketReadBuffer to + // support further reads without reading from the + // socket + } else if (s == Status.BUFFER_UNDERFLOW) { + // There is partial data in the socketReadBuffer + if (read == 0) { + // Need more data before the partial data can be + // processed and some output generated + forceRead = true; + } + // else return the data we have and deal with the + // partial data on the next read + } else if (s == Status.BUFFER_OVERFLOW) { + // Not enough space in the destination buffer to + // store all of the data. We could use a bytes read + // value of -bufferSizeRequired to signal the new + // buffer size required but an explicit exception is + // clearer. + if (reading.compareAndSet(true, false)) { + throw new ReadBufferOverflowException(sslEngine. + getSession().getApplicationBufferSize()); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } else { + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusUnwrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + } else { + forceRead = true; + } + } + + + if (reading.compareAndSet(true, false)) { + future.complete(Integer.valueOf(read)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | + ExecutionException | InterruptedException e) { + reading.set(false); + future.fail(e); + } + } + } + + + private class WebSocketSslHandshakeThread extends Thread { + + private final WrapperFuture<Void,Void> hFuture; + + private HandshakeStatus handshakeStatus; + private Status resultStatus; + + public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) { + this.hFuture = hFuture; + } + + @Override + public void run() { + try { + sslEngine.beginHandshake(); + // So the first compact does the right thing + socketReadBuffer.position(socketReadBuffer.limit()); + + handshakeStatus = sslEngine.getHandshakeStatus(); + resultStatus = Status.OK; + + boolean handshaking = true; + + while(handshaking) { + switch (handshakeStatus) { + case NEED_WRAP: { + socketWriteBuffer.clear(); + SSLEngineResult r = + sslEngine.wrap(DUMMY, socketWriteBuffer); + checkResult(r, true); + socketWriteBuffer.flip(); + Future<Integer> fWrite = + socketChannel.write(socketWriteBuffer); + fWrite.get(); + break; + } + case NEED_UNWRAP: { + socketReadBuffer.compact(); + if (socketReadBuffer.position() == 0 || + resultStatus == Status.BUFFER_UNDERFLOW) { + Future<Integer> fRead = + socketChannel.read(socketReadBuffer); + fRead.get(); + } + socketReadBuffer.flip(); + SSLEngineResult r = + sslEngine.unwrap(socketReadBuffer, DUMMY); + checkResult(r, false); + break; + } + case NEED_TASK: { + Runnable r = null; + while ((r = sslEngine.getDelegatedTask()) != null) { + r.run(); + } + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + } + case FINISHED: { + handshaking = false; + break; + } + case NOT_HANDSHAKING: { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.notHandshaking")); + } + } + } + } catch (Exception e) { + hFuture.fail(e); + return; + } + + hFuture.complete(null); + } + + private void checkResult(SSLEngineResult result, boolean wrap) + throws SSLException { + + handshakeStatus = result.getHandshakeStatus(); + resultStatus = result.getStatus(); + + if (resultStatus != Status.OK && + (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus)); + } + if (wrap && result.bytesConsumed() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap")); + } + if (!wrap && result.bytesProduced() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap")); + } + } + } + + + private static class WrapperFuture<T,A> implements Future<T> { + + private final CompletionHandler<T,A> handler; + private final A attachment; + + private volatile T result = null; + private volatile Throwable throwable = null; + private CountDownLatch completionLatch = new CountDownLatch(1); + + public WrapperFuture() { + this(null, null); + } + + public WrapperFuture(CompletionHandler<T,A> handler, A attachment) { + this.handler = handler; + this.attachment = attachment; + } + + public void complete(T result) { + this.result = result; + completionLatch.countDown(); + if (handler != null) { + handler.completed(result, attachment); + } + } + + public void fail(Throwable t) { + throwable = t; + completionLatch.countDown(); + if (handler != null) { + handler.failed(throwable, attachment); + } + } + + @Override + public final boolean cancel(boolean mayInterruptIfRunning) { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isCancelled() { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isDone() { + return completionLatch.getCount() > 0; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + completionLatch.await(); + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean latchResult = completionLatch.await(timeout, unit); + if (latchResult == false) { + throw new TimeoutException(); + } + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + } + + private static final class LongToIntegerFuture implements Future<Integer> { + + private final Future<Long> wrapped; + + public LongToIntegerFuture(Future<Long> wrapped) { + this.wrapped = wrapped; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return wrapped.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return wrapped.isCancelled(); + } + + @Override + public boolean isDone() { + return wrapped.isDone(); + } + + @Override + public Integer get() throws InterruptedException, ExecutionException { + Long result = wrapped.get(); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + + @Override + public Integer get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + Long result = wrapped.get(timeout, unit); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + } + + + private static class SecureIOThreadFactory implements ThreadFactory { + + private AtomicInteger count = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet()); + // No need to set the context class loader. The threads will be + // cleaned up when the connection is closed. + t.setDaemon(true); + return t; + } + } +} diff --git a/src/java/nginx/unit/websocket/AuthenticationException.java b/src/java/nginx/unit/websocket/AuthenticationException.java new file mode 100644 index 00000000..001f1829 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticationException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +/** + * Exception thrown on authentication error connecting to a remote + * websocket endpoint. + */ +public class AuthenticationException extends Exception { + + private static final long serialVersionUID = 5709887412240096441L; + + /** + * Create authentication exception. + * @param message the error message + */ + public AuthenticationException(String message) { + super(message); + } + +} diff --git a/src/java/nginx/unit/websocket/Authenticator.java b/src/java/nginx/unit/websocket/Authenticator.java new file mode 100644 index 00000000..87b3ce6d --- /dev/null +++ b/src/java/nginx/unit/websocket/Authenticator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Base class for the authentication methods used by the websocket client. + */ +public abstract class Authenticator { + private static final Pattern pattern = Pattern + .compile("(\\w+)\\s*=\\s*(\"([^\"]+)\"|([^,=\"]+))\\s*,?"); + + /** + * Generate the authentication header that will be sent to the server. + * @param requestUri The request URI + * @param WWWAuthenticate The server auth challenge + * @param UserProperties The user information + * @return The auth header + * @throws AuthenticationException When an error occurs + */ + public abstract String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> UserProperties) throws AuthenticationException; + + /** + * Get the authentication method. + * @return the auth scheme + */ + public abstract String getSchemeName(); + + /** + * Utility method to parse the authentication header. + * @param WWWAuthenticate The server auth challenge + * @return the parsed header + */ + public Map<String, String> parseWWWAuthenticateHeader(String WWWAuthenticate) { + + Matcher m = pattern.matcher(WWWAuthenticate); + Map<String, String> challenge = new HashMap<>(); + + while (m.find()) { + String key = m.group(1); + String qtedValue = m.group(3); + String value = m.group(4); + + challenge.put(key, qtedValue != null ? qtedValue : value); + + } + + return challenge; + + } + +} diff --git a/src/java/nginx/unit/websocket/AuthenticatorFactory.java b/src/java/nginx/unit/websocket/AuthenticatorFactory.java new file mode 100644 index 00000000..7d46d7f9 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticatorFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.Iterator; +import java.util.ServiceLoader; + +/** + * Utility method to return the appropriate authenticator according to + * the scheme that the server uses. + */ +public class AuthenticatorFactory { + + /** + * Return a new authenticator instance. + * @param authScheme The scheme used + * @return the authenticator + */ + public static Authenticator getAuthenticator(String authScheme) { + + Authenticator auth = null; + switch (authScheme.toLowerCase()) { + + case BasicAuthenticator.schemeName: + auth = new BasicAuthenticator(); + break; + + case DigestAuthenticator.schemeName: + auth = new DigestAuthenticator(); + break; + + default: + auth = loadAuthenticators(authScheme); + break; + } + + return auth; + + } + + private static Authenticator loadAuthenticators(String authScheme) { + ServiceLoader<Authenticator> serviceLoader = ServiceLoader.load(Authenticator.class); + Iterator<Authenticator> auths = serviceLoader.iterator(); + + while (auths.hasNext()) { + Authenticator auth = auths.next(); + if (auth.getSchemeName().equalsIgnoreCase(authScheme)) + return auth; + } + + return null; + } + +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcess.java b/src/java/nginx/unit/websocket/BackgroundProcess.java new file mode 100644 index 00000000..0d2e1288 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcess.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public interface BackgroundProcess { + + void backgroundProcess(); + + void setProcessPeriod(int period); + + int getProcessPeriod(); +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcessManager.java b/src/java/nginx/unit/websocket/BackgroundProcessManager.java new file mode 100644 index 00000000..d8b1b950 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcessManager.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashSet; +import java.util.Set; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Provides a background processing mechanism that triggers roughly once a + * second. The class maintains a thread that only runs when there is at least + * one instance of {@link BackgroundProcess} registered. + */ +public class BackgroundProcessManager { + + private final Log log = + LogFactory.getLog(BackgroundProcessManager.class); + private static final StringManager sm = + StringManager.getManager(BackgroundProcessManager.class); + private static final BackgroundProcessManager instance; + + + static { + instance = new BackgroundProcessManager(); + } + + + public static BackgroundProcessManager getInstance() { + return instance; + } + + private final Set<BackgroundProcess> processes = new HashSet<>(); + private final Object processesLock = new Object(); + private WsBackgroundThread wsBackgroundThread = null; + + private BackgroundProcessManager() { + // Hide default constructor + } + + + public void register(BackgroundProcess process) { + synchronized (processesLock) { + if (processes.size() == 0) { + wsBackgroundThread = new WsBackgroundThread(this); + wsBackgroundThread.setContextClassLoader( + this.getClass().getClassLoader()); + wsBackgroundThread.setDaemon(true); + wsBackgroundThread.start(); + } + processes.add(process); + } + } + + + public void unregister(BackgroundProcess process) { + synchronized (processesLock) { + processes.remove(process); + if (wsBackgroundThread != null && processes.size() == 0) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private void process() { + Set<BackgroundProcess> currentProcesses = new HashSet<>(); + synchronized (processesLock) { + currentProcesses.addAll(processes); + } + for (BackgroundProcess process : currentProcesses) { + try { + process.backgroundProcess(); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString( + "backgroundProcessManager.processFailed"), t); + } + } + } + + + /* + * For unit testing. + */ + int getProcessCount() { + synchronized (processesLock) { + return processes.size(); + } + } + + + void shutdown() { + synchronized (processesLock) { + processes.clear(); + if (wsBackgroundThread != null) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private static class WsBackgroundThread extends Thread { + + private final BackgroundProcessManager manager; + private volatile boolean running = true; + + public WsBackgroundThread(BackgroundProcessManager manager) { + setName("WebSocket background processing"); + this.manager = manager; + } + + @Override + public void run() { + while (running) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // Ignore + } + manager.process(); + } + } + + public void halt() { + setName("WebSocket background processing - stopping"); + running = false; + } + } +} diff --git a/src/java/nginx/unit/websocket/BasicAuthenticator.java b/src/java/nginx/unit/websocket/BasicAuthenticator.java new file mode 100644 index 00000000..1b1a6b83 --- /dev/null +++ b/src/java/nginx/unit/websocket/BasicAuthenticator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; + +/** + * Authenticator supporting the BASIC auth method. + */ +public class BasicAuthenticator extends Authenticator { + + public static final String schemeName = "basic"; + public static final String charsetparam = "charset"; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Basic authentication due to missing user/password"); + } + + Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String userPass = userName + ":" + password; + Charset charset; + + if (wwwAuthenticate.get(charsetparam) != null + && wwwAuthenticate.get(charsetparam).equalsIgnoreCase("UTF-8")) { + charset = StandardCharsets.UTF_8; + } else { + charset = StandardCharsets.ISO_8859_1; + } + + String base64 = Base64.getEncoder().encodeToString(userPass.getBytes(charset)); + + return " Basic " + base64; + } + + @Override + public String getSchemeName() { + return schemeName; + } + +} diff --git a/src/java/nginx/unit/websocket/Constants.java b/src/java/nginx/unit/websocket/Constants.java new file mode 100644 index 00000000..38b22fe0 --- /dev/null +++ b/src/java/nginx/unit/websocket/Constants.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.websocket.Extension; + +/** + * Internal implementation constants. + */ +public class Constants { + + // OP Codes + public static final byte OPCODE_CONTINUATION = 0x00; + public static final byte OPCODE_TEXT = 0x01; + public static final byte OPCODE_BINARY = 0x02; + public static final byte OPCODE_CLOSE = 0x08; + public static final byte OPCODE_PING = 0x09; + public static final byte OPCODE_PONG = 0x0A; + + // Internal OP Codes + // RFC 6455 limits OP Codes to 4 bits so these should never clash + // Always set bit 4 so these will be treated as control codes + static final byte INTERNAL_OPCODE_FLUSH = 0x18; + + // Buffers + static final int DEFAULT_BUFFER_SIZE = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_BUFFER_SIZE", 8 * 1024) + .intValue(); + + // Client connection + /** + * Property name to set to configure the value that is passed to + * {@link javax.net.ssl.SSLEngine#setEnabledProtocols(String[])}. The value + * should be a comma separated string. + */ + public static final String SSL_PROTOCOLS_PROPERTY = + "nginx.unit.websocket.SSL_PROTOCOLS"; + public static final String SSL_TRUSTSTORE_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE"; + public static final String SSL_TRUSTSTORE_PWD_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE_PWD"; + public static final String SSL_TRUSTSTORE_PWD_DEFAULT = "changeit"; + /** + * Property name to set to configure used SSLContext. The value should be an + * instance of SSLContext. If this property is present, the SSL_TRUSTSTORE* + * properties are ignored. + */ + public static final String SSL_CONTEXT_PROPERTY = + "nginx.unit.websocket.SSL_CONTEXT"; + /** + * Property name to set to configure the timeout (in milliseconds) when + * establishing a WebSocket connection to server. The default is + * {@link #IO_TIMEOUT_MS_DEFAULT}. + */ + public static final String IO_TIMEOUT_MS_PROPERTY = + "nginx.unit.websocket.IO_TIMEOUT_MS"; + public static final long IO_TIMEOUT_MS_DEFAULT = 5000; + + // RFC 2068 recommended a limit of 5 + // Most browsers have a default limit of 20 + public static final String MAX_REDIRECTIONS_PROPERTY = + "nginx.unit.websocket.MAX_REDIRECTIONS"; + public static final int MAX_REDIRECTIONS_DEFAULT = 20; + + // HTTP upgrade header names and values + public static final String HOST_HEADER_NAME = "Host"; + public static final String UPGRADE_HEADER_NAME = "Upgrade"; + public static final String UPGRADE_HEADER_VALUE = "websocket"; + public static final String ORIGIN_HEADER_NAME = "Origin"; + public static final String CONNECTION_HEADER_NAME = "Connection"; + public static final String CONNECTION_HEADER_VALUE = "upgrade"; + public static final String LOCATION_HEADER_NAME = "Location"; + public static final String AUTHORIZATION_HEADER_NAME = "Authorization"; + public static final String WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate"; + public static final String WS_VERSION_HEADER_NAME = "Sec-WebSocket-Version"; + public static final String WS_VERSION_HEADER_VALUE = "13"; + public static final String WS_KEY_HEADER_NAME = "Sec-WebSocket-Key"; + public static final String WS_PROTOCOL_HEADER_NAME = "Sec-WebSocket-Protocol"; + public static final String WS_EXTENSIONS_HEADER_NAME = "Sec-WebSocket-Extensions"; + + /// HTTP redirection status codes + public static final int MULTIPLE_CHOICES = 300; + public static final int MOVED_PERMANENTLY = 301; + public static final int FOUND = 302; + public static final int SEE_OTHER = 303; + public static final int USE_PROXY = 305; + public static final int TEMPORARY_REDIRECT = 307; + + // Configuration for Origin header in client + static final String DEFAULT_ORIGIN_HEADER_VALUE = + System.getProperty("nginx.unit.websocket.DEFAULT_ORIGIN_HEADER_VALUE"); + + // Configuration for blocking sends + public static final String BLOCKING_SEND_TIMEOUT_PROPERTY = + "nginx.unit.websocket.BLOCKING_SEND_TIMEOUT"; + // Milliseconds so this is 20 seconds + public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + + // Configuration for background processing checks intervals + static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_PROCESS_PERIOD", 10) + .intValue(); + + public static final String WS_AUTHENTICATION_USER_NAME = "nginx.unit.websocket.WS_AUTHENTICATION_USER_NAME"; + public static final String WS_AUTHENTICATION_PASSWORD = "nginx.unit.websocket.WS_AUTHENTICATION_PASSWORD"; + + /* Configuration for extensions + * Note: These options are primarily present to enable this implementation + * to pass compliance tests. They are expected to be removed once + * the WebSocket API includes a mechanism for adding custom extensions + * and disabling built-in extensions. + */ + static final boolean DISABLE_BUILTIN_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.DISABLE_BUILTIN_EXTENSIONS"); + static final boolean ALLOW_UNSUPPORTED_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.ALLOW_UNSUPPORTED_EXTENSIONS"); + + // Configuration for stream behavior + static final boolean STREAMS_DROP_EMPTY_MESSAGES = + Boolean.getBoolean("nginx.unit.websocket.STREAMS_DROP_EMPTY_MESSAGES"); + + public static final boolean STRICT_SPEC_COMPLIANCE = + Boolean.getBoolean("nginx.unit.websocket.STRICT_SPEC_COMPLIANCE"); + + public static final List<Extension> INSTALLED_EXTENSIONS; + + static { + if (DISABLE_BUILTIN_EXTENSIONS) { + INSTALLED_EXTENSIONS = Collections.unmodifiableList(new ArrayList<Extension>()); + } else { + List<Extension> installed = new ArrayList<>(1); + installed.add(new WsExtension("permessage-deflate")); + INSTALLED_EXTENSIONS = Collections.unmodifiableList(installed); + } + } + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/DecoderEntry.java b/src/java/nginx/unit/websocket/DecoderEntry.java new file mode 100644 index 00000000..36112ef4 --- /dev/null +++ b/src/java/nginx/unit/websocket/DecoderEntry.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Decoder; + +public class DecoderEntry { + + private final Class<?> clazz; + private final Class<? extends Decoder> decoderClazz; + + public DecoderEntry(Class<?> clazz, + Class<? extends Decoder> decoderClazz) { + this.clazz = clazz; + this.decoderClazz = decoderClazz; + } + + public Class<?> getClazz() { + return clazz; + } + + public Class<? extends Decoder> getDecoderClazz() { + return decoderClazz; + } +} diff --git a/src/java/nginx/unit/websocket/DigestAuthenticator.java b/src/java/nginx/unit/websocket/DigestAuthenticator.java new file mode 100644 index 00000000..9530c303 --- /dev/null +++ b/src/java/nginx/unit/websocket/DigestAuthenticator.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Map; + +import org.apache.tomcat.util.security.MD5Encoder; + +/** + * Authenticator supporting the DIGEST auth method. + */ +public class DigestAuthenticator extends Authenticator { + + public static final String schemeName = "digest"; + private SecureRandom cnonceGenerator; + private int nonceCount = 0; + private long cNonce; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Digest authentication due to missing user/password"); + } + + Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String realm = wwwAuthenticate.get("realm"); + String nonce = wwwAuthenticate.get("nonce"); + String messageQop = wwwAuthenticate.get("qop"); + String algorithm = wwwAuthenticate.get("algorithm") == null ? "MD5" + : wwwAuthenticate.get("algorithm"); + String opaque = wwwAuthenticate.get("opaque"); + + StringBuilder challenge = new StringBuilder(); + + if (!messageQop.isEmpty()) { + if (cnonceGenerator == null) { + cnonceGenerator = new SecureRandom(); + } + + cNonce = cnonceGenerator.nextLong(); + nonceCount++; + } + + challenge.append("Digest "); + challenge.append("username =\"" + userName + "\","); + challenge.append("realm=\"" + realm + "\","); + challenge.append("nonce=\"" + nonce + "\","); + challenge.append("uri=\"" + requestUri + "\","); + + try { + challenge.append("response=\"" + calculateRequestDigest(requestUri, userName, password, + realm, nonce, messageQop, algorithm) + "\","); + } + + catch (NoSuchAlgorithmException e) { + throw new AuthenticationException( + "Unable to generate request digest " + e.getMessage()); + } + + challenge.append("algorithm=" + algorithm + ","); + challenge.append("opaque=\"" + opaque + "\","); + + if (!messageQop.isEmpty()) { + challenge.append("qop=\"" + messageQop + "\""); + challenge.append(",cnonce=\"" + cNonce + "\","); + challenge.append("nc=" + String.format("%08X", Integer.valueOf(nonceCount))); + } + + return challenge.toString(); + + } + + private String calculateRequestDigest(String requestUri, String userName, String password, + String realm, String nonce, String qop, String algorithm) + throws NoSuchAlgorithmException { + + StringBuilder preDigest = new StringBuilder(); + String A1; + + if (algorithm.equalsIgnoreCase("MD5")) + A1 = userName + ":" + realm + ":" + password; + + else + A1 = encodeMD5(userName + ":" + realm + ":" + password) + ":" + nonce + ":" + cNonce; + + /* + * If the "qop" value is "auth-int", then A2 is: A2 = Method ":" + * digest-uri-value ":" H(entity-body) since we do not have an entity-body, A2 = + * Method ":" digest-uri-value for auth and auth_int + */ + String A2 = "GET:" + requestUri; + + preDigest.append(encodeMD5(A1)); + preDigest.append(":"); + preDigest.append(nonce); + + if (qop.toLowerCase().contains("auth")) { + preDigest.append(":"); + preDigest.append(String.format("%08X", Integer.valueOf(nonceCount))); + preDigest.append(":"); + preDigest.append(String.valueOf(cNonce)); + preDigest.append(":"); + preDigest.append(qop); + } + + preDigest.append(":"); + preDigest.append(encodeMD5(A2)); + + return encodeMD5(preDigest.toString()); + + } + + private String encodeMD5(String value) throws NoSuchAlgorithmException { + byte[] bytesOfMessage = value.getBytes(StandardCharsets.ISO_8859_1); + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] thedigest = md.digest(bytesOfMessage); + + return MD5Encoder.encode(thedigest); + } + + @Override + public String getSchemeName() { + return schemeName; + } +} diff --git a/src/java/nginx/unit/websocket/FutureToSendHandler.java b/src/java/nginx/unit/websocket/FutureToSendHandler.java new file mode 100644 index 00000000..4a0809cb --- /dev/null +++ b/src/java/nginx/unit/websocket/FutureToSendHandler.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.tomcat.util.res.StringManager; + + +/** + * Converts a Future to a SendHandler. + */ +class FutureToSendHandler implements Future<Void>, SendHandler { + + private static final StringManager sm = StringManager.getManager(FutureToSendHandler.class); + + private final CountDownLatch latch = new CountDownLatch(1); + private final WsSession wsSession; + private volatile AtomicReference<SendResult> result = new AtomicReference<>(null); + + public FutureToSendHandler(WsSession wsSession) { + this.wsSession = wsSession; + } + + + // --------------------------------------------------------- SendHandler + + @Override + public void onResult(SendResult result) { + this.result.compareAndSet(null, result); + latch.countDown(); + } + + + // -------------------------------------------------------------- Future + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isCancelled() { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isDone() { + return latch.getCount() == 0; + } + + @Override + public Void get() throws InterruptedException, + ExecutionException { + try { + wsSession.registerFuture(this); + latch.await(); + } finally { + wsSession.unregisterFuture(this); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean retval = false; + try { + wsSession.registerFuture(this); + retval = latch.await(timeout, unit); + } finally { + wsSession.unregisterFuture(this); + + } + if (retval == false) { + throw new TimeoutException(sm.getString("futureToSendHandler.timeout", + Long.valueOf(timeout), unit.toString().toLowerCase())); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } +} diff --git a/src/java/nginx/unit/websocket/LocalStrings.properties b/src/java/nginx/unit/websocket/LocalStrings.properties new file mode 100644 index 00000000..aeafe082 --- /dev/null +++ b/src/java/nginx/unit/websocket/LocalStrings.properties @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +asyncChannelGroup.createFail=Unable to create dedicated AsynchronousChannelGroup for WebSocket clients which is required to prevent memory leaks in complex class loader environments like JavaEE containers + +asyncChannelWrapperSecure.closeFail=Failed to close channel cleanly +asyncChannelWrapperSecure.check.notOk=TLS handshake returned an unexpected status [{0}] +asyncChannelWrapperSecure.check.unwrap=Bytes were written to the output during a read +asyncChannelWrapperSecure.check.wrap=Bytes were consumed from the input during a write +asyncChannelWrapperSecure.concurrentRead=Concurrent read operations are not permitted +asyncChannelWrapperSecure.concurrentWrite=Concurrent write operations are not permitted +asyncChannelWrapperSecure.eof=Unexpected end of stream +asyncChannelWrapperSecure.notHandshaking=Unexpected state [NOT_HANDSHAKING] during TLS handshake +asyncChannelWrapperSecure.readOverflow=Buffer overflow. [{0}] bytes to write into a [{1}] byte buffer that already contained [{2}] bytes. +asyncChannelWrapperSecure.statusUnwrap=Unexpected Status of SSLEngineResult after an unwrap() operation +asyncChannelWrapperSecure.statusWrap=Unexpected Status of SSLEngineResult after a wrap() operation +asyncChannelWrapperSecure.tooBig=The result [{0}] is too big to be expressed as an Integer +asyncChannelWrapperSecure.wrongStateRead=Flag that indicates a read is in progress was found to be false (it should have been true) when trying to complete a read operation +asyncChannelWrapperSecure.wrongStateWrite=Flag that indicates a write is in progress was found to be false (it should have been true) when trying to complete a write operation + +backgroundProcessManager.processFailed=A background process failed + +caseInsensitiveKeyMap.nullKey=Null keys are not permitted + +futureToSendHandler.timeout=Operation timed out after waiting [{0}] [{1}] to complete + +perMessageDeflate.deflateFailed=Failed to decompress a compressed WebSocket frame +perMessageDeflate.duplicateParameter=Duplicate definition of the [{0}] extension parameter +perMessageDeflate.invalidWindowSize=An invalid windows of [{1}] size was specified for [{0}]. Valid values are whole numbers from 8 to 15 inclusive. +perMessageDeflate.unknownParameter=An unknown extension parameter [{0}] was defined + +transformerFactory.unsupportedExtension=The extension [{0}] is not supported + +util.notToken=An illegal extension parameter was specified with name [{0}] and value [{1}] +util.invalidMessageHandler=The message handler provided does not have an onMessage(Object) method +util.invalidType=Unable to coerce value [{0}] to type [{1}]. That type is not supported. +util.unknownDecoderType=The Decoder type [{0}] is not recognized + +# Note the wsFrame.* messages are used as close reasons in WebSocket control +# frames and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsFrame.alreadyResumed=Message receiving has already been resumed. +wsFrame.alreadySuspended=Message receiving has already been suspended. +wsFrame.bufferTooSmall=No async message support and buffer too small. Buffer size: [{0}], Message size: [{1}] +wsFrame.byteToLongFail=Too many bytes ([{0}]) were provided to be converted into a long +wsFrame.closed=New frame received after a close control frame +wsFrame.controlFragmented=A fragmented control frame was received but control frames may not be fragmented +wsFrame.controlPayloadTooBig=A control frame was sent with a payload of size [{0}] which is larger than the maximum permitted of 125 bytes +wsFrame.controlNoFin=A control frame was sent that did not have the fin bit set. Control frames are not permitted to use continuation frames. +wsFrame.illegalReadState=Unexpected read state [{0}] +wsFrame.invalidOpCode= A WebSocket frame was sent with an unrecognised opCode of [{0}] +wsFrame.invalidUtf8=A WebSocket text frame was received that could not be decoded to UTF-8 because it contained invalid byte sequences +wsFrame.invalidUtf8Close=A WebSocket close frame was received with a close reason that contained invalid UTF-8 byte sequences +wsFrame.ioeTriggeredClose=An unrecoverable IOException occurred so the connection was closed +wsFrame.messageTooBig=The message was [{0}] bytes long but the MessageHandler has a limit of [{1}] bytes +wsFrame.noContinuation=A new message was started when a continuation frame was expected +wsFrame.notMasked=The client frame was not masked but all client frames must be masked +wsFrame.oneByteCloseCode=The client sent a close frame with a single byte payload which is not valid +wsFrame.partialHeaderComplete=WebSocket frame received. fin [{0}], rsv [{1}], OpCode [{2}], payload length [{3}] +wsFrame.sessionClosed=The client data cannot be processed because the session has already been closed +wsFrame.suspendRequested=Suspend of the message receiving has already been requested. +wsFrame.textMessageTooBig=The decoded text message was too big for the output buffer and the endpoint does not support partial messages +wsFrame.wrongRsv=The client frame set the reserved bits to [{0}] for a message with opCode [{1}] which was not supported by this endpoint + +wsFrameClient.ioe=Failure while reading data sent by server + +wsHandshakeRequest.invalidUri=The string [{0}] cannot be used to construct a valid URI +wsHandshakeRequest.unknownScheme=The scheme [{0}] in the request is not recognised + +wsRemoteEndpoint.acquireTimeout=The current message was not fully sent within the specified timeout +wsRemoteEndpoint.closed=Message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedDuringMessage=The remainder of the message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedOutputStream=This method may not be called as the OutputStream has been closed +wsRemoteEndpoint.closedWriter=This method may not be called as the Writer has been closed +wsRemoteEndpoint.changeType=When sending a fragmented message, all fragments must be of the same type +wsRemoteEndpoint.concurrentMessageSend=Messages may not be sent concurrently even when using the asynchronous send messages. The client must wait for the previous message to complete before sending the next. +wsRemoteEndpoint.flushOnCloseFailed=Batched messages still enabled after session has been closed. Unable to flush remaining batched message. +wsRemoteEndpoint.invalidEncoder=The specified encoder of type [{0}] could not be instantiated +wsRemoteEndpoint.noEncoder=No encoder specified for object of class [{0}] +wsRemoteEndpoint.nullData=Invalid null data argument +wsRemoteEndpoint.nullHandler=Invalid null handler argument +wsRemoteEndpoint.sendInterrupt=The current thread was interrupted while waiting for a blocking send to complete +wsRemoteEndpoint.tooMuchData=Ping or pong may not send more than 125 bytes +wsRemoteEndpoint.wrongState=The remote endpoint was in state [{0}] which is an invalid state for called method + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsSession.timeout=The WebSocket session [{0}] timeout expired + +wsSession.closed=The WebSocket session [{0}] has been closed and no method (apart from close()) may be called on a closed session +wsSession.created=Created WebSocket session [{0}] +wsSession.doClose=Closing WebSocket session [{1}] +wsSession.duplicateHandlerBinary=A binary message handler has already been configured +wsSession.duplicateHandlerPong=A pong message handler has already been configured +wsSession.duplicateHandlerText=A text message handler has already been configured +wsSession.invalidHandlerTypePong=A pong message handler must implement MessageHandler.Whole +wsSession.flushFailOnClose=Failed to flush batched messages on session close +wsSession.messageFailed=Unable to write the complete message as the WebSocket connection has been closed +wsSession.sendCloseFail=Failed to send close message for session [{0}] to remote endpoint +wsSession.removeHandlerFailed=Unable to remove the handler [{0}] as it was not registered with this session +wsSession.unknownHandler=Unable to add the message handler [{0}] as it was for the unrecognised type [{1}] +wsSession.unknownHandlerType=Unable to add the message handler [{0}] as it was wrapped as the unrecognised type [{1}] +wsSession.instanceNew=Endpoint instance registration failed +wsSession.instanceDestroy=Endpoint instance unregistration failed + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsWebSocketContainer.shutdown=The web application is stopping + +wsWebSocketContainer.defaultConfiguratorFail=Failed to create the default configurator +wsWebSocketContainer.endpointCreateFail=Failed to create a local endpoint of type [{0}] +wsWebSocketContainer.maxBuffer=This implementation limits the maximum size of a buffer to Integer.MAX_VALUE +wsWebSocketContainer.missingAnnotation=Cannot use POJO class [{0}] as it is not annotated with @ClientEndpoint +wsWebSocketContainer.sessionCloseFail=Session with ID [{0}] did not close cleanly + +wsWebSocketContainer.asynchronousSocketChannelFail=Unable to open a connection to the server +wsWebSocketContainer.httpRequestFailed=The HTTP request to initiate the WebSocket connection failed +wsWebSocketContainer.invalidExtensionParameters=The server responded with extension parameters the client is unable to support +wsWebSocketContainer.invalidHeader=Unable to parse HTTP header as no colon is present to delimit header name and header value in [{0}]. The header has been skipped. +wsWebSocketContainer.invalidStatus=The HTTP response from the server [{0}] did not permit the HTTP upgrade to WebSocket +wsWebSocketContainer.invalidSubProtocol=The WebSocket server returned multiple values for the Sec-WebSocket-Protocol header +wsWebSocketContainer.pathNoHost=No host was specified in URI +wsWebSocketContainer.pathWrongScheme=The scheme [{0}] is not supported. The supported schemes are ws and wss +wsWebSocketContainer.proxyConnectFail=Failed to connect to the configured Proxy [{0}]. The HTTP response code was [{1}] +wsWebSocketContainer.sslEngineFail=Unable to create SSLEngine to support SSL/TLS connections +wsWebSocketContainer.missingLocationHeader=Failed to handle HTTP response code [{0}]. Missing Location header in response +wsWebSocketContainer.redirectThreshold=Cyclic Location header [{0}] detected / reached max number of redirects [{1}] of max [{2}] +wsWebSocketContainer.unsupportedAuthScheme=Failed to handle HTTP response code [{0}]. Unsupported Authentication scheme [{1}] returned in response +wsWebSocketContainer.failedAuthentication=Failed to handle HTTP response code [{0}]. Authentication header was not accepted by server. +wsWebSocketContainer.missingWWWAuthenticateHeader=Failed to handle HTTP response code [{0}]. Missing WWW-Authenticate header in response diff --git a/src/java/nginx/unit/websocket/MessageHandlerResult.java b/src/java/nginx/unit/websocket/MessageHandlerResult.java new file mode 100644 index 00000000..8d532d1e --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public class MessageHandlerResult { + + private final MessageHandler handler; + private final MessageHandlerResultType type; + + + public MessageHandlerResult(MessageHandler handler, + MessageHandlerResultType type) { + this.handler = handler; + this.type = type; + } + + + public MessageHandler getHandler() { + return handler; + } + + + public MessageHandlerResultType getType() { + return type; + } +} diff --git a/src/java/nginx/unit/websocket/MessageHandlerResultType.java b/src/java/nginx/unit/websocket/MessageHandlerResultType.java new file mode 100644 index 00000000..1961bb4f --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResultType.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum MessageHandlerResultType { + BINARY, + TEXT, + PONG +} diff --git a/src/java/nginx/unit/websocket/MessagePart.java b/src/java/nginx/unit/websocket/MessagePart.java new file mode 100644 index 00000000..b52c26f1 --- /dev/null +++ b/src/java/nginx/unit/websocket/MessagePart.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.SendHandler; + +class MessagePart { + private final boolean fin; + private final int rsv; + private final byte opCode; + private final ByteBuffer payload; + private final SendHandler intermediateHandler; + private volatile SendHandler endHandler; + private final long blockingWriteTimeoutExpiry; + + public MessagePart( boolean fin, int rsv, byte opCode, ByteBuffer payload, + SendHandler intermediateHandler, SendHandler endHandler, + long blockingWriteTimeoutExpiry) { + this.fin = fin; + this.rsv = rsv; + this.opCode = opCode; + this.payload = payload; + this.intermediateHandler = intermediateHandler; + this.endHandler = endHandler; + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + } + + + public boolean isFin() { + return fin; + } + + + public int getRsv() { + return rsv; + } + + + public byte getOpCode() { + return opCode; + } + + + public ByteBuffer getPayload() { + return payload; + } + + + public SendHandler getIntermediateHandler() { + return intermediateHandler; + } + + + public SendHandler getEndHandler() { + return endHandler; + } + + public void setEndHandler(SendHandler endHandler) { + this.endHandler = endHandler; + } + + public long getBlockingWriteTimeoutExpiry() { + return blockingWriteTimeoutExpiry; + } +} + + diff --git a/src/java/nginx/unit/websocket/PerMessageDeflate.java b/src/java/nginx/unit/websocket/PerMessageDeflate.java new file mode 100644 index 00000000..88e0a0bc --- /dev/null +++ b/src/java/nginx/unit/websocket/PerMessageDeflate.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +import javax.websocket.Extension; +import javax.websocket.Extension.Parameter; +import javax.websocket.SendHandler; + +import org.apache.tomcat.util.res.StringManager; + +public class PerMessageDeflate implements Transformation { + + private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class); + + private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover"; + private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover"; + private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits"; + private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits"; + + private static final int RSV_BITMASK = 0b100; + private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1}; + + public static final String NAME = "permessage-deflate"; + + private final boolean serverContextTakeover; + private final int serverMaxWindowBits; + private final boolean clientContextTakeover; + private final int clientMaxWindowBits; + private final boolean isServer; + private final Inflater inflater = new Inflater(true); + private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1]; + + private volatile Transformation next; + private volatile boolean skipDecompression = false; + private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private volatile boolean firstCompressedFrameWritten = false; + // Flag to track if a message is completely empty + private volatile boolean emptyMessage = true; + + static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) { + // Accept the first preference that the endpoint is able to support + for (List<Parameter> preference : preferences) { + boolean ok = true; + boolean serverContextTakeover = true; + int serverMaxWindowBits = -1; + boolean clientContextTakeover = true; + int clientMaxWindowBits = -1; + + for (Parameter param : preference) { + if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (serverContextTakeover) { + serverContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_NO_CONTEXT_TAKEOVER )); + } + } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (clientContextTakeover) { + clientContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_NO_CONTEXT_TAKEOVER )); + } + } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) { + if (serverMaxWindowBits == -1) { + serverMaxWindowBits = Integer.parseInt(param.getValue()); + if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + SERVER_MAX_WINDOW_BITS, + Integer.valueOf(serverMaxWindowBits))); + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (isServer && serverMaxWindowBits != 15) { + ok = false; + break; + // Note server window size is not an issue for the + // client since the client will assume 15 and if the + // server uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_MAX_WINDOW_BITS )); + } + } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) { + if (clientMaxWindowBits == -1) { + if (param.getValue() == null) { + // Hint to server that the client supports this + // option. Java SE API (as of Java 8) does not + // expose the API to control the Window size. It is + // effectively hard-coded to 15 + clientMaxWindowBits = 15; + } else { + clientMaxWindowBits = Integer.parseInt(param.getValue()); + if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + CLIENT_MAX_WINDOW_BITS, + Integer.valueOf(clientMaxWindowBits))); + } + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (!isServer && clientMaxWindowBits != 15) { + ok = false; + break; + // Note client window size is not an issue for the + // server since the server will assume 15 and if the + // client uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_MAX_WINDOW_BITS )); + } + } else { + // Unknown parameter + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.unknownParameter", param.getName())); + } + } + if (ok) { + return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits, + clientContextTakeover, clientMaxWindowBits, isServer); + } + } + // Failed to negotiate agreeable terms + return null; + } + + + private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, + boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) { + this.serverContextTakeover = serverContextTakeover; + this.serverMaxWindowBits = serverMaxWindowBits; + this.clientContextTakeover = clientContextTakeover; + this.clientMaxWindowBits = clientMaxWindowBits; + this.isServer = isServer; + } + + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) + throws IOException { + // Control frames are never compressed and may appear in the middle of + // a WebSocket method. Pass them straight through. + if (Util.isControl(opCode)) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + if (!Util.isContinuation(opCode)) { + // First frame in new message + skipDecompression = (rsv & RSV_BITMASK) == 0; + } + + // Pass uncompressed frames straight through. + if (skipDecompression) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + int written; + boolean usedEomBytes = false; + + while (dest.remaining() > 0) { + // Space available in destination. Try and fill it. + try { + written = inflater.inflate( + dest.array(), dest.arrayOffset() + dest.position(), dest.remaining()); + } catch (DataFormatException e) { + throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e); + } + dest.position(dest.position() + written); + + if (inflater.needsInput() && !usedEomBytes ) { + if (dest.hasRemaining()) { + readBuffer.clear(); + TransformationResult nextResult = + next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer); + inflater.setInput( + readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position()); + if (TransformationResult.UNDERFLOW.equals(nextResult)) { + return nextResult; + } else if (TransformationResult.END_OF_FRAME.equals(nextResult) && + readBuffer.position() == 0) { + if (fin) { + inflater.setInput(EOM_BYTES); + usedEomBytes = true; + } else { + return TransformationResult.END_OF_FRAME; + } + } + } + } else if (written == 0) { + if (fin && (isServer && !clientContextTakeover || + !isServer && !serverContextTakeover)) { + inflater.reset(); + } + return TransformationResult.END_OF_FRAME; + } + } + + return TransformationResult.OVERFLOW; + } + + + @Override + public boolean validateRsv(int rsv, byte opCode) { + if (Util.isControl(opCode)) { + if ((rsv & RSV_BITMASK) != 0) { + return false; + } else { + if (next == null) { + return true; + } else { + return next.validateRsv(rsv, opCode); + } + } + } else { + int rsvNext = rsv; + if ((rsv & RSV_BITMASK) != 0) { + rsvNext = rsv ^ RSV_BITMASK; + } + if (next == null) { + return true; + } else { + return next.validateRsv(rsvNext, opCode); + } + } + } + + + @Override + public Extension getExtensionResponse() { + Extension result = new WsExtension(NAME); + + List<Extension.Parameter> params = result.getParameters(); + + if (!serverContextTakeover) { + params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null)); + } + if (serverMaxWindowBits != -1) { + params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS, + Integer.toString(serverMaxWindowBits))); + } + if (!clientContextTakeover) { + params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null)); + } + if (clientMaxWindowBits != -1) { + params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS, + Integer.toString(clientMaxWindowBits))); + } + + return result; + } + + + @Override + public void setNext(Transformation t) { + if (next == null) { + this.next = t; + } else { + next.setNext(t); + } + } + + + @Override + public boolean validateRsvBits(int i) { + if ((i & RSV_BITMASK) != 0) { + return false; + } + if (next == null) { + return true; + } else { + return next.validateRsvBits(i | RSV_BITMASK); + } + } + + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) { + List<MessagePart> allCompressedParts = new ArrayList<>(); + + for (MessagePart uncompressedPart : uncompressedParts) { + byte opCode = uncompressedPart.getOpCode(); + boolean emptyPart = uncompressedPart.getPayload().limit() == 0; + emptyMessage = emptyMessage && emptyPart; + if (Util.isControl(opCode)) { + // Control messages can appear in the middle of other messages + // and must not be compressed. Pass it straight through + allCompressedParts.add(uncompressedPart); + } else if (emptyMessage && uncompressedPart.isFin()) { + // Zero length messages can't be compressed so pass the + // final (empty) part straight through. + allCompressedParts.add(uncompressedPart); + } else { + List<MessagePart> compressedParts = new ArrayList<>(); + ByteBuffer uncompressedPayload = uncompressedPart.getPayload(); + SendHandler uncompressedIntermediateHandler = + uncompressedPart.getIntermediateHandler(); + + deflater.setInput(uncompressedPayload.array(), + uncompressedPayload.arrayOffset() + uncompressedPayload.position(), + uncompressedPayload.remaining()); + + int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH); + boolean deflateRequired = true; + + while (deflateRequired) { + ByteBuffer compressedPayload = writeBuffer; + + int written = deflater.deflate(compressedPayload.array(), + compressedPayload.arrayOffset() + compressedPayload.position(), + compressedPayload.remaining(), flush); + compressedPayload.position(compressedPayload.position() + written); + + if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) { + // This message part has been fully processed by the + // deflater. Fire the send handler for this message part + // and move on to the next message part. + break; + } + + // If this point is reached, a new compressed message part + // will be created... + MessagePart compressedPart; + + // .. and a new writeBuffer will be required. + writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + + // Flip the compressed payload ready for writing + compressedPayload.flip(); + + boolean fin = uncompressedPart.isFin(); + boolean full = compressedPayload.limit() == compressedPayload.capacity(); + boolean needsInput = deflater.needsInput(); + long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry(); + + if (fin && !full && needsInput) { + // End of compressed message. Drop EOM bytes and output. + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length); + compressedPart = new MessagePart(true, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else if (full && !needsInput) { + // Write buffer full and input message not fully read. + // Output and start new compressed part. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + } else if (!fin && full && needsInput) { + // Write buffer full and input message not fully read. + // Output and get more data. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + } else if (fin && full && needsInput) { + // Write buffer full. Input fully read. Deflater may be + // in one of four states: + // - output complete (just happened to align with end of + // buffer + // - in middle of EOM bytes + // - about to write EOM bytes + // - more data to write + int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH); + if (eomBufferWritten < EOM_BUFFER.length) { + // EOM has just been completed + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten); + compressedPart = new MessagePart(true, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else { + // More data to write + // Copy bytes to new write buffer + writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten); + compressedPart = new MessagePart(false, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + } + } else { + throw new IllegalStateException("Should never happen"); + } + + // Add the newly created compressed part to the set of parts + // to pass on to the next transformation. + compressedParts.add(compressedPart); + } + + SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler(); + int size = compressedParts.size(); + if (size > 0) { + compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler); + } + + allCompressedParts.addAll(compressedParts); + } + } + + if (next == null) { + return allCompressedParts; + } else { + return next.sendMessagePart(allCompressedParts); + } + } + + + private void startNewMessage() { + firstCompressedFrameWritten = false; + emptyMessage = true; + if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) { + deflater.reset(); + } + } + + + private int getRsv(MessagePart uncompressedMessagePart) { + int result = uncompressedMessagePart.getRsv(); + if (!firstCompressedFrameWritten) { + result += RSV_BITMASK; + firstCompressedFrameWritten = true; + } + return result; + } + + + @Override + public void close() { + // There will always be a next transformation + next.close(); + inflater.end(); + deflater.end(); + } +} diff --git a/src/java/nginx/unit/websocket/ReadBufferOverflowException.java b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java new file mode 100644 index 00000000..9ce7ac27 --- /dev/null +++ b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +public class ReadBufferOverflowException extends IOException { + + private static final long serialVersionUID = 1L; + + private final int minBufferSize; + + public ReadBufferOverflowException(int minBufferSize) { + this.minBufferSize = minBufferSize; + } + + public int getMinBufferSize() { + return minBufferSize; + } +} diff --git a/src/java/nginx/unit/websocket/Transformation.java b/src/java/nginx/unit/websocket/Transformation.java new file mode 100644 index 00000000..45474c7d --- /dev/null +++ b/src/java/nginx/unit/websocket/Transformation.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import javax.websocket.Extension; + +/** + * The internal representation of the transformation that a WebSocket extension + * performs on a message. + */ +public interface Transformation { + + /** + * Sets the next transformation in the pipeline. + * @param t The next transformation + */ + void setNext(Transformation t); + + /** + * Validate that the RSV bit(s) required by this transformation are not + * being used by another extension. The implementation is expected to set + * any bits it requires before passing the set of in-use bits to the next + * transformation. + * + * @param i The RSV bits marked as in use so far as an int in the + * range zero to seven with RSV1 as the MSB and RSV3 as the + * LSB + * + * @return <code>true</code> if the combination of RSV bits used by the + * transformations in the pipeline do not conflict otherwise + * <code>false</code> + */ + boolean validateRsvBits(int i); + + /** + * Obtain the extension that describes the information to be returned to the + * client. + * + * @return The extension information that describes the parameters that have + * been agreed for this transformation + */ + Extension getExtensionResponse(); + + /** + * Obtain more input data. + * + * @param opCode The opcode for the frame currently being processed + * @param fin Is this the final frame in this WebSocket message? + * @param rsv The reserved bits for the frame currently being + * processed + * @param dest The buffer in which the data is to be written + * + * @return The result of trying to read more data from the transform + * + * @throws IOException If an I/O error occurs while reading data from the + * transform + */ + TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException; + + /** + * Validates the RSV and opcode combination (assumed to have been extracted + * from a WebSocket Frame) for this extension. The implementation is + * expected to unset any RSV bits it has validated before passing the + * remaining RSV bits to the next transformation in the pipeline. + * + * @param rsv The RSV bits received as an int in the range zero to + * seven with RSV1 as the MSB and RSV3 as the LSB + * @param opCode The opCode received + * + * @return <code>true</code> if the RSV is valid otherwise + * <code>false</code> + */ + boolean validateRsv(int rsv, byte opCode); + + /** + * Takes the provided list of messages, transforms them, passes the + * transformed list on to the next transformation (if any) and then returns + * the resulting list of message parts after all of the transformations have + * been applied. + * + * @param messageParts The list of messages to be transformed + * + * @return The list of messages after this any any subsequent + * transformations have been applied. The size of the returned list + * may be bigger or smaller than the size of the input list + */ + List<MessagePart> sendMessagePart(List<MessagePart> messageParts); + + /** + * Clean-up any resources that were used by the transformation. + */ + void close(); +} diff --git a/src/java/nginx/unit/websocket/TransformationFactory.java b/src/java/nginx/unit/websocket/TransformationFactory.java new file mode 100644 index 00000000..fac04555 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationFactory.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.List; + +import javax.websocket.Extension; + +import org.apache.tomcat.util.res.StringManager; + +public class TransformationFactory { + + private static final StringManager sm = StringManager.getManager(TransformationFactory.class); + + private static final TransformationFactory factory = new TransformationFactory(); + + private TransformationFactory() { + // Hide default constructor + } + + public static TransformationFactory getInstance() { + return factory; + } + + public Transformation create(String name, List<List<Extension.Parameter>> preferences, + boolean isServer) { + if (PerMessageDeflate.NAME.equals(name)) { + return PerMessageDeflate.negotiate(preferences, isServer); + } + if (Constants.ALLOW_UNSUPPORTED_EXTENSIONS) { + return null; + } else { + throw new IllegalArgumentException( + sm.getString("transformerFactory.unsupportedExtension", name)); + } + } +} diff --git a/src/java/nginx/unit/websocket/TransformationResult.java b/src/java/nginx/unit/websocket/TransformationResult.java new file mode 100644 index 00000000..0de35e55 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationResult.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum TransformationResult { + /** + * The end of the available data was reached before the WebSocket frame was + * completely read. + */ + UNDERFLOW, + + /** + * The provided destination buffer was filled before all of the available + * data from the WebSocket frame could be processed. + */ + OVERFLOW, + + /** + * The end of the WebSocket frame was reached and all the data from that + * frame processed into the provided destination buffer. + */ + END_OF_FRAME +} diff --git a/src/java/nginx/unit/websocket/Util.java b/src/java/nginx/unit/websocket/Util.java new file mode 100644 index 00000000..6acf3ade --- /dev/null +++ b/src/java/nginx/unit/websocket/Util.java @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText; + +/** + * Utility class for internal use only within the + * {@link nginx.unit.websocket} package. + */ +public class Util { + + private static final StringManager sm = StringManager.getManager(Util.class); + private static final Queue<SecureRandom> randoms = + new ConcurrentLinkedQueue<>(); + + private Util() { + // Hide default constructor + } + + + static boolean isControl(byte opCode) { + return (opCode & 0x08) != 0; + } + + + static boolean isText(byte opCode) { + return opCode == Constants.OPCODE_TEXT; + } + + + static boolean isContinuation(byte opCode) { + return opCode == Constants.OPCODE_CONTINUATION; + } + + + static CloseCode getCloseCode(int code) { + if (code > 2999 && code < 5000) { + return CloseCodes.getCloseCode(code); + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + // Should not be used in a close frame + // return CloseCodes.RESERVED; + return CloseCodes.PROTOCOL_ERROR; + case 1005: + // Should not be used in a close frame + // return CloseCodes.NO_STATUS_CODE; + return CloseCodes.PROTOCOL_ERROR; + case 1006: + // Should not be used in a close frame + // return CloseCodes.CLOSED_ABNORMALLY; + return CloseCodes.PROTOCOL_ERROR; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + // Not in RFC6455 + // return CloseCodes.SERVICE_RESTART; + return CloseCodes.PROTOCOL_ERROR; + case 1013: + // Not in RFC6455 + // return CloseCodes.TRY_AGAIN_LATER; + return CloseCodes.PROTOCOL_ERROR; + case 1015: + // Should not be used in a close frame + // return CloseCodes.TLS_HANDSHAKE_FAILURE; + return CloseCodes.PROTOCOL_ERROR; + default: + return CloseCodes.PROTOCOL_ERROR; + } + } + + + static byte[] generateMask() { + // SecureRandom is not thread-safe so need to make sure only one thread + // uses it at a time. In theory, the pool could grow to the same size + // as the number of request processing threads. In reality it will be + // a lot smaller. + + // Get a SecureRandom from the pool + SecureRandom sr = randoms.poll(); + + // If one isn't available, generate a new one + if (sr == null) { + try { + sr = SecureRandom.getInstance("SHA1PRNG"); + } catch (NoSuchAlgorithmException e) { + // Fall back to platform default + sr = new SecureRandom(); + } + } + + // Generate the mask + byte[] result = new byte[4]; + sr.nextBytes(result); + + // Put the SecureRandom back in the poll + randoms.add(sr); + + return result; + } + + + static Class<?> getMessageType(MessageHandler listener) { + return Util.getGenericType(MessageHandler.class, + listener.getClass()).getClazz(); + } + + + private static Class<?> getDecoderType(Class<? extends Decoder> decoder) { + return Util.getGenericType(Decoder.class, decoder).getClazz(); + } + + + static Class<?> getEncoderType(Class<? extends Encoder> encoder) { + return Util.getGenericType(Encoder.class, encoder).getClazz(); + } + + + private static <T> TypeResult getGenericType(Class<T> type, + Class<? extends T> clazz) { + + // Look to see if this class implements the interface of interest + + // Get all the interfaces + Type[] interfaces = clazz.getGenericInterfaces(); + for (Type iface : interfaces) { + // Only need to check interfaces that use generics + if (iface instanceof ParameterizedType) { + ParameterizedType pi = (ParameterizedType) iface; + // Look for the interface of interest + if (pi.getRawType() instanceof Class) { + if (type.isAssignableFrom((Class<?>) pi.getRawType())) { + return getTypeParameter( + clazz, pi.getActualTypeArguments()[0]); + } + } + } + } + + // Interface not found on this class. Look at the superclass. + @SuppressWarnings("unchecked") + Class<? extends T> superClazz = + (Class<? extends T>) clazz.getSuperclass(); + if (superClazz == null) { + // Finished looking up the class hierarchy without finding anything + return null; + } + + TypeResult superClassTypeResult = getGenericType(type, superClazz); + int dimension = superClassTypeResult.getDimension(); + if (superClassTypeResult.getIndex() == -1 && dimension == 0) { + // Superclass implements interface and defines explicit type for + // the interface of interest + return superClassTypeResult; + } + + if (superClassTypeResult.getIndex() > -1) { + // Superclass implements interface and defines unknown type for + // the interface of interest + // Map that unknown type to the generic types defined in this class + ParameterizedType superClassType = + (ParameterizedType) clazz.getGenericSuperclass(); + TypeResult result = getTypeParameter(clazz, + superClassType.getActualTypeArguments()[ + superClassTypeResult.getIndex()]); + result.incrementDimension(superClassTypeResult.getDimension()); + if (result.getClazz() != null && result.getDimension() > 0) { + superClassTypeResult = result; + } else { + return result; + } + } + + if (superClassTypeResult.getDimension() > 0) { + StringBuilder className = new StringBuilder(); + for (int i = 0; i < dimension; i++) { + className.append('['); + } + className.append('L'); + className.append(superClassTypeResult.getClazz().getCanonicalName()); + className.append(';'); + + Class<?> arrayClazz; + try { + arrayClazz = Class.forName(className.toString()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + + return new TypeResult(arrayClazz, -1, 0); + } + + // Error will be logged further up the call stack + return null; + } + + + /* + * For a generic parameter, return either the Class used or if the type + * is unknown, the index for the type in definition of the class + */ + private static TypeResult getTypeParameter(Class<?> clazz, Type argType) { + if (argType instanceof Class<?>) { + return new TypeResult((Class<?>) argType, -1, 0); + } else if (argType instanceof ParameterizedType) { + return new TypeResult((Class<?>)((ParameterizedType) argType).getRawType(), -1, 0); + } else if (argType instanceof GenericArrayType) { + Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType(); + TypeResult result = getTypeParameter(clazz, arrayElementType); + result.incrementDimension(1); + return result; + } else { + TypeVariable<?>[] tvs = clazz.getTypeParameters(); + for (int i = 0; i < tvs.length; i++) { + if (tvs[i].equals(argType)) { + return new TypeResult(null, i, 0); + } + } + return null; + } + } + + + public static boolean isPrimitive(Class<?> clazz) { + if (clazz.isPrimitive()) { + return true; + } else if(clazz.equals(Boolean.class) || + clazz.equals(Byte.class) || + clazz.equals(Character.class) || + clazz.equals(Double.class) || + clazz.equals(Float.class) || + clazz.equals(Integer.class) || + clazz.equals(Long.class) || + clazz.equals(Short.class)) { + return true; + } + return false; + } + + + public static Object coerceToType(Class<?> type, String value) { + if (type.equals(String.class)) { + return value; + } else if (type.equals(boolean.class) || type.equals(Boolean.class)) { + return Boolean.valueOf(value); + } else if (type.equals(byte.class) || type.equals(Byte.class)) { + return Byte.valueOf(value); + } else if (type.equals(char.class) || type.equals(Character.class)) { + return Character.valueOf(value.charAt(0)); + } else if (type.equals(double.class) || type.equals(Double.class)) { + return Double.valueOf(value); + } else if (type.equals(float.class) || type.equals(Float.class)) { + return Float.valueOf(value); + } else if (type.equals(int.class) || type.equals(Integer.class)) { + return Integer.valueOf(value); + } else if (type.equals(long.class) || type.equals(Long.class)) { + return Long.valueOf(value); + } else if (type.equals(short.class) || type.equals(Short.class)) { + return Short.valueOf(value); + } else { + throw new IllegalArgumentException(sm.getString( + "util.invalidType", value, type.getName())); + } + } + + + public static List<DecoderEntry> getDecoders( + List<Class<? extends Decoder>> decoderClazzes) + throws DeploymentException{ + + List<DecoderEntry> result = new ArrayList<>(); + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Decoder instance; + try { + instance = decoderClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("pojoMethodMapping.invalidDecoder", + decoderClazz.getName()), e); + } + DecoderEntry entry = new DecoderEntry( + Util.getDecoderType(decoderClazz), decoderClazz); + result.add(entry); + } + } + + return result; + } + + + static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, + MessageHandler listener, EndpointConfig endpointConfig, + Session session) { + + // Will never be more than 2 types + Set<MessageHandlerResult> results = new HashSet<>(2); + + // Simple cases - handlers already accepts one of the types expected by + // the frame handling code + if (String.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.TEXT); + results.add(result); + } else if (ByteBuffer.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.BINARY); + results.add(result); + } else if (PongMessage.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.PONG); + results.add(result); + // Handler needs wrapping and optional decoder to convert it to one of + // the types expected by the frame handling code + } else if (byte[].class.isAssignableFrom(target)) { + boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass()); + MessageHandlerResult result = new MessageHandlerResult( + whole ? new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, false, -1) : + new PojoMessageHandlerPartialBinary(listener, + getOnMessagePartialMethod(listener), session, + new Object[2], 0, true, 1, -1, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (InputStream.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, true, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (Reader.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, false), + new Object[1], 0, true, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } else { + // Handler needs wrapping and requires decoder to convert it to one + // of the types expected by the frame handling code + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + Method m = getOnMessageMethod(listener); + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, m, session, + endpointConfig, + decoderMatch.getBinaryDecoders(), new Object[1], + 0, false, -1, false, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, m, session, + endpointConfig, + decoderMatch.getTextDecoders(), new Object[1], + 0, false, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } + } + + if (results.size() == 0) { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandler", listener, target)); + } + + return results; + } + + private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, + EndpointConfig endpointConfig, boolean binary) { + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + if (binary) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + return decoderMatch.getBinaryDecoders(); + } + } else if (decoderMatch.getTextDecoders().size() > 0) { + return decoderMatch.getTextDecoders(); + } + return null; + } + + private static DecoderMatch matchDecoders(Class<?> target, + EndpointConfig endpointConfig) { + DecoderMatch decoderMatch; + try { + List<Class<? extends Decoder>> decoders = + endpointConfig.getDecoders(); + List<DecoderEntry> decoderEntries = getDecoders(decoders); + decoderMatch = new DecoderMatch(target, decoderEntries); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); + } + return decoderMatch; + } + + public static void parseExtensionHeader(List<Extension> extensions, + String header) { + // The relevant ABNF for the Sec-WebSocket-Extensions is as follows: + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ; When using the quoted-string syntax variant, the value + // ; after quoted-string unescaping MUST conform to the + // ; 'token' ABNF. + // + // The limiting of parameter values to tokens or "quoted tokens" makes + // the parsing of the header significantly simpler and allows a number + // of short-cuts to be taken. + + // Step one, split the header into individual extensions using ',' as a + // separator + String unparsedExtensions[] = header.split(","); + for (String unparsedExtension : unparsedExtensions) { + // Step two, split the extension into the registered name and + // parameter/value pairs using ';' as a separator + String unparsedParameters[] = unparsedExtension.split(";"); + WsExtension extension = new WsExtension(unparsedParameters[0].trim()); + + for (int i = 1; i < unparsedParameters.length; i++) { + int equalsPos = unparsedParameters[i].indexOf('='); + String name; + String value; + if (equalsPos == -1) { + name = unparsedParameters[i].trim(); + value = null; + } else { + name = unparsedParameters[i].substring(0, equalsPos).trim(); + value = unparsedParameters[i].substring(equalsPos + 1).trim(); + int len = value.length(); + if (len > 1) { + if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') { + value = value.substring(1, value.length() - 1); + } + } + } + // Make sure value doesn't contain any of the delimiters since + // that would indicate something went wrong + if (containsDelims(name) || containsDelims(value)) { + throw new IllegalArgumentException(sm.getString( + "util.notToken", name, value)); + } + if (value != null && + (value.indexOf(',') > -1 || value.indexOf(';') > -1 || + value.indexOf('\"') > -1 || value.indexOf('=') > -1)) { + throw new IllegalArgumentException(sm.getString("", value)); + } + extension.addParameter(new WsExtensionParameter(name, value)); + } + extensions.add(extension); + } + } + + + private static boolean containsDelims(String input) { + if (input == null || input.length() == 0) { + return false; + } + for (char c : input.toCharArray()) { + switch (c) { + case ',': + case ';': + case '\"': + case '=': + return true; + default: + // NO_OP + } + + } + return false; + } + + private static Method getOnMessageMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + private static Method getOnMessagePartialMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + + public static class DecoderMatch { + + private final List<Class<? extends Decoder>> textDecoders = + new ArrayList<>(); + private final List<Class<? extends Decoder>> binaryDecoders = + new ArrayList<>(); + private final Class<?> target; + + public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) { + this.target = target; + for (DecoderEntry decoderEntry : decoderEntries) { + if (decoderEntry.getClazz().isAssignableFrom(target)) { + if (Binary.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (BinaryStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else if (Text.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (TextStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else { + throw new IllegalArgumentException( + sm.getString("util.unknownDecoderType")); + } + } + } + } + + + public List<Class<? extends Decoder>> getTextDecoders() { + return textDecoders; + } + + + public List<Class<? extends Decoder>> getBinaryDecoders() { + return binaryDecoders; + } + + + public Class<?> getTarget() { + return target; + } + + + public boolean hasMatches() { + return (textDecoders.size() > 0) || (binaryDecoders.size() > 0); + } + } + + + private static class TypeResult { + private final Class<?> clazz; + private final int index; + private int dimension; + + public TypeResult(Class<?> clazz, int index, int dimension) { + this.clazz= clazz; + this.index = index; + this.dimension = dimension; + } + + public Class<?> getClazz() { + return clazz; + } + + public int getIndex() { + return index; + } + + public int getDimension() { + return dimension; + } + + public void incrementDimension(int inc) { + dimension += inc; + } + } +} diff --git a/src/java/nginx/unit/websocket/WrappedMessageHandler.java b/src/java/nginx/unit/websocket/WrappedMessageHandler.java new file mode 100644 index 00000000..2557a73e --- /dev/null +++ b/src/java/nginx/unit/websocket/WrappedMessageHandler.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public interface WrappedMessageHandler { + long getMaxMessageSize(); + + MessageHandler getWrappedHandler(); +} diff --git a/src/java/nginx/unit/websocket/WsContainerProvider.java b/src/java/nginx/unit/websocket/WsContainerProvider.java new file mode 100644 index 00000000..f8a404a1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsContainerProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.ContainerProvider; +import javax.websocket.WebSocketContainer; + +public class WsContainerProvider extends ContainerProvider { + + @Override + protected WebSocketContainer getContainer() { + return new WsWebSocketContainer(); + } +} diff --git a/src/java/nginx/unit/websocket/WsExtension.java b/src/java/nginx/unit/websocket/WsExtension.java new file mode 100644 index 00000000..3846feb1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtension.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.Extension; + +public class WsExtension implements Extension { + + private final String name; + private final List<Parameter> parameters = new ArrayList<>(); + + WsExtension(String name) { + this.name = name; + } + + void addParameter(Parameter parameter) { + parameters.add(parameter); + } + + @Override + public String getName() { + return name; + } + + @Override + public List<Parameter> getParameters() { + return parameters; + } +} diff --git a/src/java/nginx/unit/websocket/WsExtensionParameter.java b/src/java/nginx/unit/websocket/WsExtensionParameter.java new file mode 100644 index 00000000..9b82f1c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtensionParameter.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Extension.Parameter; + +public class WsExtensionParameter implements Parameter { + + private final String name; + private final String value; + + WsExtensionParameter(String name, String value) { + this.name = name; + this.value = value; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getValue() { + return value; + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameBase.java b/src/java/nginx/unit/websocket/WsFrameBase.java new file mode 100644 index 00000000..06d20bf4 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameBase.java @@ -0,0 +1,1010 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.util.List; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; + +import org.apache.juli.logging.Log; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +/** + * Takes the ServletInputStream, processes the WebSocket frames it contains and + * extracts the messages. WebSocket Pings received will be responded to + * automatically without any action required by the application. + */ +public abstract class WsFrameBase { + + private static final StringManager sm = StringManager.getManager(WsFrameBase.class); + + // Connection level attributes + protected final WsSession wsSession; + protected final ByteBuffer inputBuffer; + private final Transformation transformation; + + // Attributes for control messages + // Control messages can appear in the middle of other messages so need + // separate attributes + private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125); + private final CharBuffer controlBufferText = CharBuffer.allocate(125); + + // Attributes of the current message + private final CharsetDecoder utf8DecoderControl = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private boolean continuationExpected = false; + private boolean textMessage = false; + private ByteBuffer messageBufferBinary; + private CharBuffer messageBufferText; + // Cache the message handler in force when the message starts so it is used + // consistently for the entire message + private MessageHandler binaryMsgHandler = null; + private MessageHandler textMsgHandler = null; + + // Attributes of the current frame + private boolean fin = false; + private int rsv = 0; + private byte opCode = 0; + private final byte[] mask = new byte[4]; + private int maskIndex = 0; + private long payloadLength = 0; + private volatile long payloadWritten = 0; + + // Attributes tracking state + private volatile State state = State.NEW_FRAME; + private volatile boolean open = true; + + private static final AtomicReferenceFieldUpdater<WsFrameBase, ReadState> READ_STATE_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(WsFrameBase.class, ReadState.class, "readState"); + private volatile ReadState readState = ReadState.WAITING; + + public WsFrameBase(WsSession wsSession, Transformation transformation) { + inputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + inputBuffer.position(0).limit(0); + messageBufferBinary = ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize()); + messageBufferText = CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize()); + this.wsSession = wsSession; + Transformation finalTransformation; + if (isMasked()) { + finalTransformation = new UnmaskTransformation(); + } else { + finalTransformation = new NoopTransformation(); + } + if (transformation == null) { + this.transformation = finalTransformation; + } else { + transformation.setNext(finalTransformation); + this.transformation = transformation; + } + } + + + protected void processInputBuffer() throws IOException { + while (!isSuspended()) { + wsSession.updateLastActive(); + if (state == State.NEW_FRAME) { + if (!processInitialHeader()) { + break; + } + // If a close frame has been received, no further data should + // have seen + if (!open) { + throw new IOException(sm.getString("wsFrame.closed")); + } + } + if (state == State.PARTIAL_HEADER) { + if (!processRemainingHeader()) { + break; + } + } + if (state == State.DATA) { + if (!processData()) { + break; + } + } + } + } + + + /** + * @return <code>true</code> if sufficient data was present to process all + * of the initial header + */ + private boolean processInitialHeader() throws IOException { + // Need at least two bytes of data to do this + if (inputBuffer.remaining() < 2) { + return false; + } + int b = inputBuffer.get(); + fin = (b & 0x80) != 0; + rsv = (b & 0x70) >>> 4; + opCode = (byte) (b & 0x0F); + if (!transformation.validateRsv(rsv, opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.wrongRsv", Integer.valueOf(rsv), Integer.valueOf(opCode)))); + } + + if (Util.isControl(opCode)) { + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlFragmented"))); + } + if (opCode != Constants.OPCODE_PING && + opCode != Constants.OPCODE_PONG && + opCode != Constants.OPCODE_CLOSE) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } else { + if (continuationExpected) { + if (!Util.isContinuation(opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.noContinuation"))); + } + } else { + try { + if (opCode == Constants.OPCODE_BINARY) { + // New binary message + textMessage = false; + int size = wsSession.getMaxBinaryMessageBufferSize(); + if (size != messageBufferBinary.capacity()) { + messageBufferBinary = ByteBuffer.allocate(size); + } + binaryMsgHandler = wsSession.getBinaryMessageHandler(); + textMsgHandler = null; + } else if (opCode == Constants.OPCODE_TEXT) { + // New text message + textMessage = true; + int size = wsSession.getMaxTextMessageBufferSize(); + if (size != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(size); + } + binaryMsgHandler = null; + textMsgHandler = wsSession.getTextMessageHandler(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } catch (IllegalStateException ise) { + // Thrown if the session is already closed + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.sessionClosed"))); + } + } + continuationExpected = !fin; + } + b = inputBuffer.get(); + // Client data must be masked + if ((b & 0x80) == 0 && isMasked()) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.notMasked"))); + } + payloadLength = b & 0x7F; + state = State.PARTIAL_HEADER; + if (getLog().isDebugEnabled()) { + getLog().debug(sm.getString("wsFrame.partialHeaderComplete", Boolean.toString(fin), + Integer.toString(rsv), Integer.toString(opCode), Long.toString(payloadLength))); + } + return true; + } + + + protected abstract boolean isMasked(); + protected abstract Log getLog(); + + + /** + * @return <code>true</code> if sufficient data was present to complete the + * processing of the header + */ + private boolean processRemainingHeader() throws IOException { + // Ignore the 2 bytes already read. 4 for the mask + int headerLength; + if (isMasked()) { + headerLength = 4; + } else { + headerLength = 0; + } + // Add additional bytes depending on length + if (payloadLength == 126) { + headerLength += 2; + } else if (payloadLength == 127) { + headerLength += 8; + } + if (inputBuffer.remaining() < headerLength) { + return false; + } + // Calculate new payload length if necessary + if (payloadLength == 126) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 2); + inputBuffer.position(inputBuffer.position() + 2); + } else if (payloadLength == 127) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 8); + inputBuffer.position(inputBuffer.position() + 8); + } + if (Util.isControl(opCode)) { + if (payloadLength > 125) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlPayloadTooBig", Long.valueOf(payloadLength)))); + } + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlNoFin"))); + } + } + if (isMasked()) { + inputBuffer.get(mask, 0, 4); + } + state = State.DATA; + return true; + } + + + private boolean processData() throws IOException { + boolean result; + if (Util.isControl(opCode)) { + result = processDataControl(); + } else if (textMessage) { + if (textMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataText(); + } + } else { + if (binaryMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataBinary(); + } + } + checkRoomPayload(); + return result; + } + + + private boolean processDataControl() throws IOException { + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, controlBufferBinary); + if (TransformationResult.UNDERFLOW.equals(tr)) { + return false; + } + // Control messages have fixed message size so + // TransformationResult.OVERFLOW is not possible here + + controlBufferBinary.flip(); + if (opCode == Constants.OPCODE_CLOSE) { + open = false; + String reason = null; + int code = CloseCodes.NORMAL_CLOSURE.getCode(); + if (controlBufferBinary.remaining() == 1) { + controlBufferBinary.clear(); + // Payload must be zero or 2+ bytes long + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.oneByteCloseCode"))); + } + if (controlBufferBinary.remaining() > 1) { + code = controlBufferBinary.getShort(); + if (controlBufferBinary.remaining() > 0) { + CoderResult cr = utf8DecoderControl.decode(controlBufferBinary, + controlBufferText, true); + if (cr.isError()) { + controlBufferBinary.clear(); + controlBufferText.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidUtf8Close"))); + } + // There will be no overflow as the output buffer is big + // enough. There will be no underflow as all the data is + // passed to the decoder in a single call. + controlBufferText.flip(); + reason = controlBufferText.toString(); + } + } + wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason)); + } else if (opCode == Constants.OPCODE_PING) { + if (wsSession.isOpen()) { + wsSession.getBasicRemote().sendPong(controlBufferBinary); + } + } else if (opCode == Constants.OPCODE_PONG) { + MessageHandler.Whole<PongMessage> mhPong = wsSession.getPongMessageHandler(); + if (mhPong != null) { + try { + mhPong.onMessage(new WsPongMessage(controlBufferBinary)); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + controlBufferBinary.clear(); + } + } + } else { + // Should have caught this earlier but just in case... + controlBufferBinary.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + controlBufferBinary.clear(); + newFrame(); + return true; + } + + + @SuppressWarnings("unchecked") + protected void sendMessageText(boolean last) throws WsIOException { + if (textMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && messageBufferText.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(messageBufferText.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + + try { + if (textMsgHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<String>) textMsgHandler) + .onMessage(messageBufferText.toString(), last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<String>) textMsgHandler) + .onMessage(messageBufferText.toString()); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + messageBufferText.clear(); + } + } + + + private boolean processDataText() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - we ran out of something + // Convert bytes to UTF-8 + messageBufferBinary.flip(); + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + false); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow()) { + // Compact what we have to create as much space as possible + messageBufferBinary.compact(); + + // Need more input + // What did we run out of? + if (TransformationResult.OVERFLOW.equals(tr)) { + // Ran out of message buffer - exit inner loop and + // refill + break; + } else { + // TransformationResult.UNDERFLOW + // Ran out of input data - get some more + return false; + } + } + } + // Read more input data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + messageBufferBinary.flip(); + boolean last = false; + // Frame is fully received + // Convert bytes to UTF-8 + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + // End of frame and possible message as well. + + if (continuationExpected) { + // If partial messages are supported, send what we have + // managed to decode + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } + messageBufferBinary.compact(); + newFrame(); + // Process next frame + return true; + } else { + // Make sure coder has flushed all output + last = true; + } + } else { + // End of message + messageBufferText.flip(); + sendMessageText(true); + + newMessage(); + return true; + } + } + } + + + private boolean processDataBinary() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - what did we run out of? + if (TransformationResult.UNDERFLOW.equals(tr)) { + // Ran out of input data - get some more + return false; + } + + // Ran out of message buffer - flush it + if (!usePartial()) { + CloseReason cr = new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.bufferTooSmall", + Integer.valueOf(messageBufferBinary.capacity()), + Long.valueOf(payloadLength))); + throw new WsIOException(cr); + } + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, false); + messageBufferBinary.clear(); + // Read more data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + // Frame is fully received + // Send the message if either: + // - partial messages are supported + // - the message is complete + if (usePartial() || !continuationExpected) { + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, !continuationExpected); + messageBufferBinary.clear(); + } + + if (continuationExpected) { + // More data for this message expected, start a new frame + newFrame(); + } else { + // Message is complete, start a new message + newMessage(); + } + + return true; + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + wsSession.getLocal().onError(wsSession, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + @SuppressWarnings("unchecked") + protected void sendMessageBinary(ByteBuffer msg, boolean last) throws WsIOException { + if (binaryMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(msg.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + try { + if (binaryMsgHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<ByteBuffer>) binaryMsgHandler).onMessage(msg, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<ByteBuffer>) binaryMsgHandler).onMessage(msg); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } + } + + + private void newMessage() { + messageBufferBinary.clear(); + messageBufferText.clear(); + utf8DecoderMessage.reset(); + continuationExpected = false; + newFrame(); + } + + + private void newFrame() { + if (inputBuffer.remaining() == 0) { + inputBuffer.position(0).limit(0); + } + + maskIndex = 0; + payloadWritten = 0; + state = State.NEW_FRAME; + + // These get reset in processInitialHeader() + // fin, rsv, opCode, payloadLength, mask + + checkRoomHeaders(); + } + + + private void checkRoomHeaders() { + // Is the start of the current frame too near the end of the input + // buffer? + if (inputBuffer.capacity() - inputBuffer.position() < 131) { + // Limit based on a control frame with a full payload + makeRoom(); + } + } + + + private void checkRoomPayload() { + if (inputBuffer.capacity() - inputBuffer.position() - payloadLength + payloadWritten < 0) { + makeRoom(); + } + } + + + private void makeRoom() { + inputBuffer.compact(); + inputBuffer.flip(); + } + + + private boolean usePartial() { + if (Util.isControl(opCode)) { + return false; + } else if (textMessage) { + return textMsgHandler instanceof MessageHandler.Partial; + } else { + // Must be binary + return binaryMsgHandler instanceof MessageHandler.Partial; + } + } + + + private boolean swallowInput() { + long toSkip = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + inputBuffer.position(inputBuffer.position() + (int) toSkip); + payloadWritten += toSkip; + if (payloadWritten == payloadLength) { + if (continuationExpected) { + newFrame(); + } else { + newMessage(); + } + return true; + } else { + return false; + } + } + + + protected static long byteArrayToLong(byte[] b, int start, int len) throws IOException { + if (len > 8) { + throw new IOException(sm.getString("wsFrame.byteToLongFail", Long.valueOf(len))); + } + int shift = 0; + long result = 0; + for (int i = start + len - 1; i >= start; i--) { + result = result + ((b[i] & 0xFF) << shift); + shift += 8; + } + return result; + } + + + protected boolean isOpen() { + return open; + } + + + protected Transformation getTransformation() { + return transformation; + } + + + private enum State { + NEW_FRAME, PARTIAL_HEADER, DATA + } + + + /** + * WAITING - not suspended + * Server case: waiting for a notification that data + * is ready to be read from the socket, the socket is + * registered to the poller + * Client case: data has been read from the socket and + * is waiting for data to be processed + * PROCESSING - not suspended + * Server case: reading from the socket and processing + * the data + * Client case: processing the data if such has + * already been read and more data will be read from + * the socket + * SUSPENDING_WAIT - suspended, a call to suspend() was made while in + * WAITING state. A call to resume() will do nothing + * and will transition to WAITING state + * SUSPENDING_PROCESS - suspended, a call to suspend() was made while in + * PROCESSING state. A call to resume() will do + * nothing and will transition to PROCESSING state + * SUSPENDED - suspended + * Server case: processing data finished + * (SUSPENDING_PROCESS) / a notification was received + * that data is ready to be read from the socket + * (SUSPENDING_WAIT), socket is not registered to the + * poller + * Client case: processing data finished + * (SUSPENDING_PROCESS) / data has been read from the + * socket and is available for processing + * (SUSPENDING_WAIT) + * A call to resume() will: + * Server case: register the socket to the poller + * Client case: resume data processing + * CLOSING - not suspended, a close will be send + * + * <pre> + * resume data to be resume + * no action processed no action + * |---------------| |---------------| |----------| + * | v | v v | + * | |----------WAITING --------PROCESSING----| | + * | | ^ processing | | + * | | | finished | | + * | | | | | + * | suspend | suspend | + * | | | | | + * | | resume | | + * | | register socket to poller (server) | | + * | | resume data processing (client) | | + * | | | | | + * | v | v | + * SUSPENDING_WAIT | SUSPENDING_PROCESS + * | | | + * | data available | processing finished | + * |------------- SUSPENDED ----------------------| + * </pre> + */ + protected enum ReadState { + WAITING (false), + PROCESSING (false), + SUSPENDING_WAIT (true), + SUSPENDING_PROCESS(true), + SUSPENDED (true), + CLOSING (false); + + private final boolean isSuspended; + + ReadState(boolean isSuspended) { + this.isSuspended = isSuspended; + } + + public boolean isSuspended() { + return isSuspended; + } + } + + public void suspend() { + while (true) { + switch (readState) { + case WAITING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.WAITING, + ReadState.SUSPENDING_WAIT)) { + continue; + } + return; + case PROCESSING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.PROCESSING, + ReadState.SUSPENDING_PROCESS)) { + continue; + } + return; + case SUSPENDING_WAIT: + if (readState != ReadState.SUSPENDING_WAIT) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDING_PROCESS: + if (readState != ReadState.SUSPENDING_PROCESS) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDED: + if (readState != ReadState.SUSPENDED) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadySuspended")); + } + } + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + public void resume() { + while (true) { + switch (readState) { + case WAITING: + if (readState != ReadState.WAITING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case PROCESSING: + if (readState != ReadState.PROCESSING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case SUSPENDING_WAIT: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_WAIT, + ReadState.WAITING)) { + continue; + } + return; + case SUSPENDING_PROCESS: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_PROCESS, + ReadState.PROCESSING)) { + continue; + } + return; + case SUSPENDED: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDED, + ReadState.WAITING)) { + continue; + } + resumeProcessing(); + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + protected boolean isSuspended() { + return readState.isSuspended(); + } + + protected ReadState getReadState() { + return readState; + } + + protected void changeReadState(ReadState newState) { + READ_STATE_UPDATER.set(this, newState); + } + + protected boolean changeReadState(ReadState oldState, ReadState newState) { + return READ_STATE_UPDATER.compareAndSet(this, oldState, newState); + } + + /** + * This method will be invoked when the read operation is resumed. + * As the suspend of the read operation can be invoked at any time, when + * implementing this method one should consider that there might still be + * data remaining into the internal buffers that needs to be processed + * before reading again from the socket. + */ + protected abstract void resumeProcessing(); + + + private abstract class TerminalTransformation implements Transformation { + + @Override + public boolean validateRsvBits(int i) { + // Terminal transformations don't use RSV bits and there is no next + // transformation so always return true. + return true; + } + + @Override + public Extension getExtensionResponse() { + // Return null since terminal transformations are not extensions + return null; + } + + @Override + public void setNext(Transformation t) { + // NO-OP since this is the terminal transformation + } + + /** + * {@inheritDoc} + * <p> + * Anything other than a value of zero for rsv is invalid. + */ + @Override + public boolean validateRsv(int rsv, byte opCode) { + return rsv == 0; + } + + @Override + public void close() { + // NO-OP for the terminal transformations + } + } + + + /** + * For use by the client implementation that needs to obtain payload data + * without the need for unmasking. + */ + private final class NoopTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + long toWrite = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + toWrite = Math.min(toWrite, dest.remaining()); + + int orgLimit = inputBuffer.limit(); + inputBuffer.limit(inputBuffer.position() + (int) toWrite); + dest.put(inputBuffer); + inputBuffer.limit(orgLimit); + payloadWritten += toWrite; + + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) { + // TODO Masking should move to this method + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } + + + /** + * For use by the server implementation that needs to obtain payload data + * and unmask it before any further processing. + */ + private final class UnmaskTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + while (payloadWritten < payloadLength && inputBuffer.remaining() > 0 && + dest.hasRemaining()) { + byte b = (byte) ((inputBuffer.get() ^ mask[maskIndex]) & 0xFF); + maskIndex++; + if (maskIndex == 4) { + maskIndex = 0; + } + payloadWritten++; + dest.put(b); + } + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) { + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameClient.java b/src/java/nginx/unit/websocket/WsFrameClient.java new file mode 100644 index 00000000..3174c766 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameClient.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +public class WsFrameClient extends WsFrameBase { + + private final Log log = LogFactory.getLog(WsFrameClient.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsFrameClient.class); + + private final AsyncChannelWrapper channel; + private final CompletionHandler<Integer, Void> handler; + // Not final as it may need to be re-sized + private volatile ByteBuffer response; + + public WsFrameClient(ByteBuffer response, AsyncChannelWrapper channel, WsSession wsSession, + Transformation transformation) { + super(wsSession, transformation); + this.response = response; + this.channel = channel; + this.handler = new WsFrameClientCompletionHandler(); + } + + + void startInputProcessing() { + try { + processSocketRead(); + } catch (IOException e) { + close(e); + } + } + + + private void processSocketRead() throws IOException { + while (true) { + switch (getReadState()) { + case WAITING: + if (!changeReadState(ReadState.WAITING, ReadState.PROCESSING)) { + continue; + } + while (response.hasRemaining()) { + if (isSuspended()) { + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + // There is still data available in the response buffer + // Return here so that the response buffer will not be + // cleared and there will be no data read from the + // socket. Thus when the read operation is resumed first + // the data left in the response buffer will be consumed + // and then a new socket read will be performed + return; + } + inputBuffer.mark(); + inputBuffer.position(inputBuffer.limit()).limit(inputBuffer.capacity()); + + int toCopy = Math.min(response.remaining(), inputBuffer.remaining()); + + // Copy remaining bytes read in HTTP phase to input buffer used by + // frame processing + + int orgLimit = response.limit(); + response.limit(response.position() + toCopy); + inputBuffer.put(response); + response.limit(orgLimit); + + inputBuffer.limit(inputBuffer.position()).reset(); + + // Process the data we have + processInputBuffer(); + } + response.clear(); + + // Get some more data + if (isOpen()) { + channel.read(response, null, handler); + } else { + changeReadState(ReadState.CLOSING); + } + return; + case SUSPENDING_WAIT: + if (!changeReadState(ReadState.SUSPENDING_WAIT, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrameServer.illegalReadState", getReadState())); + } + } + } + + + private final void close(Throwable t) { + changeReadState(ReadState.CLOSING); + CloseReason cr; + if (t instanceof WsIOException) { + cr = ((WsIOException) t).getCloseReason(); + } else { + cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, t.getMessage()); + } + + try { + wsSession.close(cr); + } catch (IOException ignore) { + // Ignore + } + } + + + @Override + protected boolean isMasked() { + // Data is from the server so it is not masked + return false; + } + + + @Override + protected Log getLog() { + return log; + } + + private class WsFrameClientCompletionHandler implements CompletionHandler<Integer, Void> { + + @Override + public void completed(Integer result, Void attachment) { + if (result.intValue() == -1) { + // BZ 57762. A dropped connection will get reported as EOF + // rather than as an error so handle it here. + if (isOpen()) { + // No close frame was received + close(new EOFException()); + } + // No data to process + return; + } + response.flip(); + doResumeProcessing(true); + } + + @Override + public void failed(Throwable exc, Void attachment) { + if (exc instanceof ReadBufferOverflowException) { + // response will be empty if this exception is thrown + response = ByteBuffer + .allocate(((ReadBufferOverflowException) exc).getMinBufferSize()); + response.flip(); + doResumeProcessing(false); + } else { + close(exc); + } + } + + private void doResumeProcessing(boolean checkOpenOnError) { + while (true) { + switch (getReadState()) { + case PROCESSING: + if (!changeReadState(ReadState.PROCESSING, ReadState.WAITING)) { + continue; + } + resumeProcessing(checkOpenOnError); + return; + case SUSPENDING_PROCESS: + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrame.illegalReadState", getReadState())); + } + } + } + } + + + @Override + protected void resumeProcessing() { + resumeProcessing(true); + } + + private void resumeProcessing(boolean checkOpenOnError) { + try { + processSocketRead(); + } catch (IOException e) { + if (checkOpenOnError) { + // Only send a close message on an IOException if the client + // has not yet received a close control message from the server + // as the IOException may be in response to the client + // continuing to send a message after the server sent a close + // control message. + if (isOpen()) { + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsFrameClient.ioe"), e); + } + close(e); + } + } else { + close(e); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/WsHandshakeResponse.java b/src/java/nginx/unit/websocket/WsHandshakeResponse.java new file mode 100644 index 00000000..6e57ffd5 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsHandshakeResponse.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.websocket.HandshakeResponse; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; + +/** + * Represents the response to a WebSocket handshake. + */ +public class WsHandshakeResponse implements HandshakeResponse { + + private final Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); + + + public WsHandshakeResponse() { + } + + + public WsHandshakeResponse(Map<String,List<String>> headers) { + for (Entry<String,List<String>> entry : headers.entrySet()) { + if (this.headers.containsKey(entry.getKey())) { + this.headers.get(entry.getKey()).addAll(entry.getValue()); + } else { + List<String> values = new ArrayList<>(entry.getValue()); + this.headers.put(entry.getKey(), values); + } + } + } + + + @Override + public Map<String,List<String>> getHeaders() { + return headers; + } +} diff --git a/src/java/nginx/unit/websocket/WsIOException.java b/src/java/nginx/unit/websocket/WsIOException.java new file mode 100644 index 00000000..0362dc1d --- /dev/null +++ b/src/java/nginx/unit/websocket/WsIOException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +import javax.websocket.CloseReason; + +/** + * Allows the WebSocket implementation to throw an {@link IOException} that + * includes a {@link CloseReason} specific to the error that can be passed back + * to the client. + */ +public class WsIOException extends IOException { + + private static final long serialVersionUID = 1L; + + private final CloseReason closeReason; + + public WsIOException(CloseReason closeReason) { + this.closeReason = closeReason; + } + + public CloseReason getCloseReason() { + return closeReason; + } +} diff --git a/src/java/nginx/unit/websocket/WsPongMessage.java b/src/java/nginx/unit/websocket/WsPongMessage.java new file mode 100644 index 00000000..531bcda9 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsPongMessage.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.PongMessage; + +public class WsPongMessage implements PongMessage { + + private final ByteBuffer applicationData; + + + public WsPongMessage(ByteBuffer applicationData) { + byte[] dst = new byte[applicationData.limit()]; + applicationData.get(dst); + this.applicationData = ByteBuffer.wrap(dst); + } + + + @Override + public ByteBuffer getApplicationData() { + return applicationData; + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java new file mode 100644 index 00000000..0ea20795 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.util.concurrent.Future; + +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; + +public class WsRemoteEndpointAsync extends WsRemoteEndpointBase + implements RemoteEndpoint.Async { + + WsRemoteEndpointAsync(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public long getSendTimeout() { + return base.getSendTimeout(); + } + + + @Override + public void setSendTimeout(long timeout) { + base.setSendTimeout(timeout); + } + + + @Override + public void sendText(String text, SendHandler completion) { + base.sendStringByCompletion(text, completion); + } + + + @Override + public Future<Void> sendText(String text) { + return base.sendStringByFuture(text); + } + + + @Override + public Future<Void> sendBinary(ByteBuffer data) { + return base.sendBytesByFuture(data); + } + + + @Override + public void sendBinary(ByteBuffer data, SendHandler completion) { + base.sendBytesByCompletion(data, completion); + } + + + @Override + public Future<Void> sendObject(Object obj) { + return base.sendObjectByFuture(obj); + } + + + @Override + public void sendObject(Object obj, SendHandler completion) { + base.sendObjectByCompletion(obj, completion); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java new file mode 100644 index 00000000..21cb2040 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import javax.websocket.RemoteEndpoint; + +public abstract class WsRemoteEndpointBase implements RemoteEndpoint { + + protected final WsRemoteEndpointImplBase base; + + + WsRemoteEndpointBase(WsRemoteEndpointImplBase base) { + this.base = base; + } + + + @Override + public final void setBatchingAllowed(boolean batchingAllowed) throws IOException { + base.setBatchingAllowed(batchingAllowed); + } + + + @Override + public final boolean getBatchingAllowed() { + return base.getBatchingAllowed(); + } + + + @Override + public final void flushBatch() throws IOException { + base.flushBatch(); + } + + + @Override + public final void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPing(applicationData); + } + + + @Override + public final void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPong(applicationData); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java new file mode 100644 index 00000000..2a93cc7b --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.RemoteEndpoint; + +public class WsRemoteEndpointBasic extends WsRemoteEndpointBase + implements RemoteEndpoint.Basic { + + WsRemoteEndpointBasic(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public void sendText(String text) throws IOException { + base.sendString(text); + } + + + @Override + public void sendBinary(ByteBuffer data) throws IOException { + base.sendBytes(data); + } + + + @Override + public void sendText(String fragment, boolean isLast) throws IOException { + base.sendPartialString(fragment, isLast); + } + + + @Override + public void sendBinary(ByteBuffer partialByte, boolean isLast) + throws IOException { + base.sendPartialBytes(partialByte, isLast); + } + + + @Override + public OutputStream getSendStream() throws IOException { + return base.getSendStream(); + } + + + @Override + public Writer getSendWriter() throws IOException { + return base.getSendWriter(); + } + + + @Override + public void sendObject(Object o) throws IOException, EncodeException { + base.sendObject(o); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java new file mode 100644 index 00000000..776124fd --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.EncodeException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.buf.Utf8Encoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplBase.class); + + protected static final SendResult SENDRESULT_OK = new SendResult(); + + private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static + + private final StateMachine stateMachine = new StateMachine(); + + private final IntermediateMessageHandler intermediateMessageHandler = + new IntermediateMessageHandler(this); + + private Transformation transformation = null; + private final Semaphore messagePartInProgress = new Semaphore(1); + private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>(); + private final Object messagePartLock = new Object(); + + // State + private volatile boolean closed = false; + private boolean fragmented = false; + private boolean nextFragmented = false; + private boolean text = false; + private boolean nextText = false; + + // Max size of WebSocket header is 14 bytes + private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); + private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final CharsetEncoder encoder = new Utf8Encoder(); + private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final AtomicBoolean batchingAllowed = new AtomicBoolean(false); + private volatile long sendTimeout = -1; + private WsSession wsSession; + private List<EncoderEntry> encoderEntries = new ArrayList<>(); + + private Request request; + + + protected void setTransformation(Transformation transformation) { + this.transformation = transformation; + } + + + public long getSendTimeout() { + return sendTimeout; + } + + + public void setSendTimeout(long timeout) { + this.sendTimeout = timeout; + } + + + @Override + public void setBatchingAllowed(boolean batchingAllowed) throws IOException { + boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); + + if (oldValue && !batchingAllowed) { + flushBatch(); + } + } + + + @Override + public boolean getBatchingAllowed() { + return batchingAllowed.get(); + } + + + @Override + public void flushBatch() throws IOException { + sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true); + } + + + public void sendBytes(ByteBuffer data) throws IOException { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryStart(); + sendMessageBlock(Constants.OPCODE_BINARY, data, true); + stateMachine.complete(true); + } + + + public Future<Void> sendBytesByFuture(ByteBuffer data) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendBytesByCompletion(data, f2sh); + return f2sh; + } + + + public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine); + stateMachine.binaryStart(); + startMessage(Constants.OPCODE_BINARY, data, true, sush); + } + + + public void sendPartialBytes(ByteBuffer partialByte, boolean last) + throws IOException { + if (partialByte == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryPartialStart(); + sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last); + stateMachine.complete(last); + } + + + @Override + public void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PING, applicationData, true); + } + + + @Override + public void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PONG, applicationData, true); + } + + + public void sendString(String text) throws IOException { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textStart(); + sendMessageBlock(CharBuffer.wrap(text), true); + } + + + public Future<Void> sendStringByFuture(String text) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendStringByCompletion(text, f2sh); + return f2sh; + } + + + public void sendStringByCompletion(String text, SendHandler handler) { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + stateMachine.textStart(); + TextMessageSendHandler tmsh = new TextMessageSendHandler(handler, + CharBuffer.wrap(text), true, encoder, encoderBuffer, this); + tmsh.write(); + // TextMessageSendHandler will update stateMachine when it completes + } + + + public void sendPartialString(String fragment, boolean isLast) + throws IOException { + if (fragment == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textPartialStart(); + sendMessageBlock(CharBuffer.wrap(fragment), isLast); + } + + + public OutputStream getSendStream() { + stateMachine.streamStart(); + return new WsOutputStream(this); + } + + + public Writer getSendWriter() { + stateMachine.writeStart(); + return new WsWriter(this); + } + + + void sendMessageBlock(CharBuffer part, boolean last) throws IOException { + long timeoutExpiry = getTimeoutExpiry(); + boolean isDone = false; + while (!isDone) { + encoderBuffer.clear(); + CoderResult cr = encoder.encode(part, encoderBuffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + encoderBuffer.flip(); + sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry); + } + stateMachine.complete(last); + } + + + void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last) + throws IOException { + sendMessageBlock(opCode, payload, last, getTimeoutExpiry()); + } + + + private long getTimeoutExpiry() { + // Get the timeout before we send the message. The message may + // trigger a session close and depending on timing the client + // session may close before we can read the timeout. + long timeout = getBlockingSendTimeout(); + if (timeout < 0) { + return Long.MAX_VALUE; + } else { + return System.currentTimeMillis() + timeout; + } + } + + private byte currentOpCode = Constants.OPCODE_CONTINUATION; + + private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, + long timeoutExpiry) throws IOException { + wsSession.updateLastActive(); + + if (opCode == currentOpCode) { + opCode = Constants.OPCODE_CONTINUATION; + } + + request.sendWsFrame(payload, opCode, last, timeoutExpiry); + + if (!last && opCode != Constants.OPCODE_CONTINUATION) { + currentOpCode = opCode; + } + + if (last && opCode == Constants.OPCODE_CONTINUATION) { + currentOpCode = Constants.OPCODE_CONTINUATION; + } + } + + + void startMessage(byte opCode, ByteBuffer payload, boolean last, + SendHandler handler) { + + wsSession.updateLastActive(); + + List<MessagePart> messageParts = new ArrayList<>(); + messageParts.add(new MessagePart(last, 0, opCode, payload, + intermediateMessageHandler, + new EndMessageHandler(this, handler), -1)); + + messageParts = transformation.sendMessagePart(messageParts); + + // Some extensions/transformations may buffer messages so it is possible + // that no message parts will be returned. If this is the case the + // trigger the supplied SendHandler + if (messageParts.size() == 0) { + handler.onResult(new SendResult()); + return; + } + + MessagePart mp = messageParts.remove(0); + + boolean doWrite = false; + synchronized (messagePartLock) { + if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) { + // Should not happen. To late to send batched messages now since + // the session has been closed. Complain loudly. + log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed")); + } + if (messagePartInProgress.tryAcquire()) { + doWrite = true; + } else { + // When a control message is sent while another message is being + // sent, the control message is queued. Chances are the + // subsequent data message part will end up queued while the + // control message is sent. The logic in this class (state + // machine, EndMessageHandler, TextMessageSendHandler) ensures + // that there will only ever be one data message part in the + // queue. There could be multiple control messages in the queue. + + // Add it to the queue + messagePartQueue.add(mp); + } + // Add any remaining messages to the queue + messagePartQueue.addAll(messageParts); + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mp); + } + } + + + void endMessage(SendHandler handler, SendResult result) { + boolean doWrite = false; + MessagePart mpNext = null; + synchronized (messagePartLock) { + + fragmented = nextFragmented; + text = nextText; + + mpNext = messagePartQueue.poll(); + if (mpNext == null) { + messagePartInProgress.release(); + } else if (!closed){ + // Session may have been closed unexpectedly in the middle of + // sending a fragmented message closing the endpoint. If this + // happens, clearly there is no point trying to send the rest of + // the message. + doWrite = true; + } + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mpNext); + } + + wsSession.updateLastActive(); + + // Some handlers, such as the IntermediateMessageHandler, do not have a + // nested handler so handler may be null. + if (handler != null) { + handler.onResult(result); + } + } + + + void writeMessagePart(MessagePart mp) { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closed")); + } + + if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) { + nextFragmented = fragmented; + nextText = text; + outputBuffer.flip(); + SendHandler flushHandler = new OutputBufferFlushSendHandler( + outputBuffer, mp.getEndHandler()); + doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer); + return; + } + + // Control messages may be sent in the middle of fragmented message + // so they have no effect on the fragmented or text flags + boolean first; + if (Util.isControl(mp.getOpCode())) { + nextFragmented = fragmented; + nextText = text; + if (mp.getOpCode() == Constants.OPCODE_CLOSE) { + closed = true; + } + first = true; + } else { + boolean isText = Util.isText(mp.getOpCode()); + + if (fragmented) { + // Currently fragmented + if (text != isText) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.changeType")); + } + nextText = text; + nextFragmented = !mp.isFin(); + first = false; + } else { + // Wasn't fragmented. Might be now + if (mp.isFin()) { + nextFragmented = false; + } else { + nextFragmented = true; + nextText = isText; + } + first = true; + } + } + + byte[] mask; + + if (isMasked()) { + mask = Util.generateMask(); + } else { + mask = null; + } + + headerBuffer.clear(); + writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(), + isMasked(), mp.getPayload(), mask, first); + headerBuffer.flip(); + + if (getBatchingAllowed() || isMasked()) { + // Need to write via output buffer + OutputBufferSendHandler obsh = new OutputBufferSendHandler( + mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload(), mask, + outputBuffer, !getBatchingAllowed(), this); + obsh.write(); + } else { + // Can write directly + doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload()); + } + } + + + private long getBlockingSendTimeout() { + Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY); + Long userTimeout = null; + if (obj instanceof Long) { + userTimeout = (Long) obj; + } + if (userTimeout == null) { + return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT; + } else { + return userTimeout.longValue(); + } + } + + + /** + * Wraps the user provided handler so that the end point is notified when + * the message is complete. + */ + private static class EndMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + private final SendHandler handler; + + public EndMessageHandler(WsRemoteEndpointImplBase endpoint, + SendHandler handler) { + this.endpoint = endpoint; + this.handler = handler; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(handler, result); + } + } + + + /** + * If a transformation needs to split a {@link MessagePart} into multiple + * {@link MessagePart}s, it uses this handler as the end handler for each of + * the additional {@link MessagePart}s. This handler notifies this this + * class that the {@link MessagePart} has been processed and that the next + * {@link MessagePart} in the queue should be started. The final + * {@link MessagePart} will use the {@link EndMessageHandler} provided with + * the original {@link MessagePart}. + */ + private static class IntermediateMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + + public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(null, result); + } + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObject(Object obj) throws IOException, EncodeException { + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendString(msg); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytes(msg); + return; + } + + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendString(msg); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytes(msg); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } + + + public Future<Void> sendObjectByFuture(Object obj) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendObjectByCompletion(obj, f2sh); + return f2sh; + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObjectByCompletion(Object obj, SendHandler completion) { + + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (completion == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendStringByCompletion(msg, completion); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytesByCompletion(msg, completion); + return; + } + + try { + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendStringByCompletion(msg, completion); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + completion.onResult(new SendResult()); + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytesByCompletion(msg, completion); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + completion.onResult(new SendResult()); + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } catch (Exception e) { + SendResult sr = new SendResult(e); + completion.onResult(sr); + } + } + + + protected void setSession(WsSession wsSession) { + this.wsSession = wsSession; + } + + + protected void setRequest(Request request) { + this.request = request; + } + + protected void setEncoders(EndpointConfig endpointConfig) + throws DeploymentException { + encoderEntries.clear(); + for (Class<? extends Encoder> encoderClazz : + endpointConfig.getEncoders()) { + Encoder instance; + try { + instance = encoderClazz.getConstructor().newInstance(); + instance.init(endpointConfig); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("wsRemoteEndpoint.invalidEncoder", + encoderClazz.getName()), e); + } + EncoderEntry entry = new EncoderEntry( + Util.getEncoderType(encoderClazz), instance); + encoderEntries.add(entry); + } + } + + + private Encoder findEncoder(Object obj) { + for (EncoderEntry entry : encoderEntries) { + if (entry.getClazz().isAssignableFrom(obj.getClass())) { + return entry.getEncoder(); + } + } + return null; + } + + + public final void close() { + for (EncoderEntry entry : encoderEntries) { + entry.getEncoder().destroy(); + } + + request.closeWs(); + } + + + protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data); + protected abstract boolean isMasked(); + protected abstract void doClose(); + + private static void writeHeader(ByteBuffer headerBuffer, boolean fin, + int rsv, byte opCode, boolean masked, ByteBuffer payload, + byte[] mask, boolean first) { + + byte b = 0; + + if (fin) { + // Set the fin bit + b -= 128; + } + + b += (rsv << 4); + + if (first) { + // This is the first fragment of this message + b += opCode; + } + // If not the first fragment, it is a continuation with opCode of zero + + headerBuffer.put(b); + + if (masked) { + b = (byte) 0x80; + } else { + b = 0; + } + + // Next write the mask && length length + if (payload.limit() < 126) { + headerBuffer.put((byte) (payload.limit() | b)); + } else if (payload.limit() < 65536) { + headerBuffer.put((byte) (126 | b)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } else { + // Will never be more than 2^31-1 + headerBuffer.put((byte) (127 | b)); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) (payload.limit() >>> 24)); + headerBuffer.put((byte) (payload.limit() >>> 16)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } + if (masked) { + headerBuffer.put(mask[0]); + headerBuffer.put(mask[1]); + headerBuffer.put(mask[2]); + headerBuffer.put(mask[3]); + } + } + + + private class TextMessageSendHandler implements SendHandler { + + private final SendHandler handler; + private final CharBuffer message; + private final boolean isLast; + private final CharsetEncoder encoder; + private final ByteBuffer buffer; + private final WsRemoteEndpointImplBase endpoint; + private volatile boolean isDone = false; + + public TextMessageSendHandler(SendHandler handler, CharBuffer message, + boolean isLast, CharsetEncoder encoder, + ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) { + this.handler = handler; + this.message = message; + this.isLast = isLast; + this.encoder = encoder.reset(); + this.buffer = encoderBuffer; + this.endpoint = endpoint; + } + + public void write() { + buffer.clear(); + CoderResult cr = encoder.encode(message, buffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + buffer.flip(); + endpoint.startMessage(Constants.OPCODE_TEXT, buffer, + isDone && isLast, this); + } + + @Override + public void onResult(SendResult result) { + if (isDone) { + endpoint.stateMachine.complete(isLast); + handler.onResult(result); + } else if(!result.isOK()) { + handler.onResult(result); + } else if (closed){ + SendResult sr = new SendResult(new IOException( + sm.getString("wsRemoteEndpoint.closedDuringMessage"))); + handler.onResult(sr); + } else { + write(); + } + } + } + + + /** + * Used to write data to the output buffer, flushing the buffer if it fills + * up. + */ + private static class OutputBufferSendHandler implements SendHandler { + + private final SendHandler handler; + private final long blockingWriteTimeoutExpiry; + private final ByteBuffer headerBuffer; + private final ByteBuffer payload; + private final byte[] mask; + private final ByteBuffer outputBuffer; + private final boolean flushRequired; + private final WsRemoteEndpointImplBase endpoint; + private int maskIndex = 0; + + public OutputBufferSendHandler(SendHandler completion, + long blockingWriteTimeoutExpiry, + ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, + ByteBuffer outputBuffer, boolean flushRequired, + WsRemoteEndpointImplBase endpoint) { + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + this.handler = completion; + this.headerBuffer = headerBuffer; + this.payload = payload; + this.mask = mask; + this.outputBuffer = outputBuffer; + this.flushRequired = flushRequired; + this.endpoint = endpoint; + } + + public void write() { + // Write the header + while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put(headerBuffer.get()); + } + if (headerBuffer.hasRemaining()) { + // Still more headers to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + // Write the payload + int payloadLeft = payload.remaining(); + int payloadLimit = payload.limit(); + int outputSpace = outputBuffer.remaining(); + int toWrite = payloadLeft; + + if (payloadLeft > outputSpace) { + toWrite = outputSpace; + // Temporarily reduce the limit + payload.limit(payload.position() + toWrite); + } + + if (mask == null) { + // Use a bulk copy + outputBuffer.put(payload); + } else { + for (int i = 0; i < toWrite; i++) { + outputBuffer.put( + (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); + if (maskIndex > 3) { + maskIndex = 0; + } + } + } + + if (payloadLeft > outputSpace) { + // Restore the original limit + payload.limit(payloadLimit); + // Still more data to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + if (flushRequired) { + outputBuffer.flip(); + if (outputBuffer.remaining() == 0) { + handler.onResult(SENDRESULT_OK); + } else { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } + } else { + handler.onResult(SENDRESULT_OK); + } + } + + // ------------------------------------------------- SendHandler methods + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + if (outputBuffer.hasRemaining()) { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } else { + outputBuffer.clear(); + write(); + } + } else { + handler.onResult(result); + } + } + } + + + /** + * Ensures that the output buffer is cleared after it has been flushed. + */ + private static class OutputBufferFlushSendHandler implements SendHandler { + + private final ByteBuffer outputBuffer; + private final SendHandler handler; + + public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) { + this.outputBuffer = outputBuffer; + this.handler = handler; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + outputBuffer.clear(); + } + handler.onResult(result); + } + } + + + private static class WsOutputStream extends OutputStream { + + private final WsRemoteEndpointImplBase endpoint; + private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsOutputStream(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(int b) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + buffer.put((byte) b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > b.length) || (len < 0) || + ((off + len) > b.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(b, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(b, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + // Optimisation. If there is no data to flush then do not send an + // empty message. + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last); + } + endpoint.stateMachine.complete(last); + buffer.clear(); + } + } + + + private static class WsWriter extends Writer { + + private final WsRemoteEndpointImplBase endpoint; + private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsWriter(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(char[] cbuf, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(cbuf, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(cbuf, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(buffer, last); + buffer.clear(); + } else { + endpoint.stateMachine.complete(last); + } + } + } + + + private static class EncoderEntry { + + private final Class<?> clazz; + private final Encoder encoder; + + public EncoderEntry(Class<?> clazz, Encoder encoder) { + this.clazz = clazz; + this.encoder = encoder; + } + + public Class<?> getClazz() { + return clazz; + } + + public Encoder getEncoder() { + return encoder; + } + } + + + private enum State { + OPEN, + STREAM_WRITING, + WRITER_WRITING, + BINARY_PARTIAL_WRITING, + BINARY_PARTIAL_READY, + BINARY_FULL_WRITING, + TEXT_PARTIAL_WRITING, + TEXT_PARTIAL_READY, + TEXT_FULL_WRITING + } + + + private static class StateMachine { + private State state = State.OPEN; + + public synchronized void streamStart() { + checkState(State.OPEN); + state = State.STREAM_WRITING; + } + + public synchronized void writeStart() { + checkState(State.OPEN); + state = State.WRITER_WRITING; + } + + public synchronized void binaryPartialStart() { + checkState(State.OPEN, State.BINARY_PARTIAL_READY); + state = State.BINARY_PARTIAL_WRITING; + } + + public synchronized void binaryStart() { + checkState(State.OPEN); + state = State.BINARY_FULL_WRITING; + } + + public synchronized void textPartialStart() { + checkState(State.OPEN, State.TEXT_PARTIAL_READY); + state = State.TEXT_PARTIAL_WRITING; + } + + public synchronized void textStart() { + checkState(State.OPEN); + state = State.TEXT_FULL_WRITING; + } + + public synchronized void complete(boolean last) { + if (last) { + checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING, + State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + state = State.OPEN; + } else { + checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + if (state == State.TEXT_PARTIAL_WRITING) { + state = State.TEXT_PARTIAL_READY; + } else if (state == State.BINARY_PARTIAL_WRITING){ + state = State.BINARY_PARTIAL_READY; + } else if (state == State.WRITER_WRITING) { + // NO-OP. Leave state as is. + } else if (state == State.STREAM_WRITING) { + // NO-OP. Leave state as is. + } else { + // Should never happen + // The if ... else ... blocks above should cover all states + // permitted by the preceding checkState() call + throw new IllegalStateException( + "BUG: This code should never be called"); + } + } + } + + private void checkState(State... required) { + for (State state : required) { + if (this.state == state) { + return; + } + } + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.wrongState", this.state)); + } + } + + + private static class StateUpdateSendHandler implements SendHandler { + + private final SendHandler handler; + private final StateMachine stateMachine; + + public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) { + this.handler = handler; + this.stateMachine = stateMachine; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + stateMachine.complete(true); + } + handler.onResult(result); + } + } + + + private static class BlockingSendHandler implements SendHandler { + + private SendResult sendResult = null; + + @Override + public void onResult(SendResult result) { + sendResult = result; + } + + public SendResult getSendResult() { + return sendResult; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java new file mode 100644 index 00000000..70b66789 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase { + + private final AsyncChannelWrapper channel; + + public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) { + this.channel = channel; + } + + + @Override + protected boolean isMasked() { + return true; + } + + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data) { + long timeout; + for (ByteBuffer byteBuffer : data) { + if (blockingWriteTimeoutExpiry == -1) { + timeout = getSendTimeout(); + if (timeout < 1) { + timeout = Long.MAX_VALUE; + } + } else { + timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis(); + if (timeout < 0) { + SendResult sr = new SendResult(new IOException("Blocking write timeout")); + handler.onResult(sr); + } + } + + try { + channel.write(byteBuffer).get(timeout, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + handler.onResult(new SendResult(e)); + return; + } + } + handler.onResult(SENDRESULT_OK); + } + + @Override + protected void doClose() { + channel.close(); + } +} diff --git a/src/java/nginx/unit/websocket/WsSession.java b/src/java/nginx/unit/websocket/WsSession.java new file mode 100644 index 00000000..b654eb37 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsSession.java @@ -0,0 +1,1070 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.channels.WritePendingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.MessageHandler.Partial; +import javax.websocket.MessageHandler.Whole; +import javax.websocket.PongMessage; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendResult; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.InstanceManagerBindings; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public class WsSession implements Session { + + // An ellipsis is a single character that looks like three periods in a row + // and is used to indicate a continuation. + private static final byte[] ELLIPSIS_BYTES = "\u2026".getBytes(StandardCharsets.UTF_8); + // An ellipsis is three bytes in UTF-8 + private static final int ELLIPSIS_BYTES_LEN = ELLIPSIS_BYTES.length; + + private static final StringManager sm = StringManager.getManager(WsSession.class); + private static AtomicLong ids = new AtomicLong(0); + + private final Log log = LogFactory.getLog(WsSession.class); // must not be static + + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + + private final Endpoint localEndpoint; + private final WsRemoteEndpointImplBase wsRemoteEndpoint; + private final RemoteEndpoint.Async remoteEndpointAsync; + private final RemoteEndpoint.Basic remoteEndpointBasic; + private final ClassLoader applicationClassLoader; + private final WsWebSocketContainer webSocketContainer; + private final URI requestUri; + private final Map<String, List<String>> requestParameterMap; + private final String queryString; + private final Principal userPrincipal; + private final EndpointConfig endpointConfig; + + private final List<Extension> negotiatedExtensions; + private final String subProtocol; + private final Map<String, String> pathParameters; + private final boolean secure; + private final String httpSessionId; + private final String id; + + // Expected to handle message types of <String> only + private volatile MessageHandler textMessageHandler = null; + // Expected to handle message types of <ByteBuffer> only + private volatile MessageHandler binaryMessageHandler = null; + private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null; + private volatile State state = State.OPEN; + private final Object stateLock = new Object(); + private final Map<String, Object> userProperties = new ConcurrentHashMap<>(); + private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long maxIdleTimeout = 0; + private volatile long lastActive = System.currentTimeMillis(); + private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>(); + + private CharBuffer messageBufferText; + private ByteBuffer binaryBuffer; + private byte startOpCode = Constants.OPCODE_CONTINUATION; + + /** + * Creates a new WebSocket session for communication between the two + * provided end points. The result of {@link Thread#getContextClassLoader()} + * at the time this constructor is called will be used when calling + * {@link Endpoint#onClose(Session, CloseReason)}. + * + * @param localEndpoint The end point managed by this code + * @param wsRemoteEndpoint The other / remote endpoint + * @param wsWebSocketContainer The container that created this session + * @param requestUri The URI used to connect to this endpoint or + * <code>null</code> is this is a client session + * @param requestParameterMap The parameters associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param queryString The query string associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param userPrincipal The principal associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param httpSessionId The HTTP session ID associated with the + * request that initiated this session or + * <code>null</code> if this is a client session + * @param negotiatedExtensions The agreed extensions to use for this session + * @param subProtocol The agreed subprotocol to use for this + * session + * @param pathParameters The path parameters associated with the + * request that initiated this session or + * <code>null</code> if this is a client session + * @param secure Was this session initiated over a secure + * connection? + * @param endpointConfig The configuration information for the + * endpoint + * @throws DeploymentException if an invalid encode is specified + */ + public WsSession(Endpoint localEndpoint, + WsRemoteEndpointImplBase wsRemoteEndpoint, + WsWebSocketContainer wsWebSocketContainer, + URI requestUri, Map<String, List<String>> requestParameterMap, + String queryString, Principal userPrincipal, String httpSessionId, + List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters, + boolean secure, EndpointConfig endpointConfig, + Request request) throws DeploymentException { + this.localEndpoint = localEndpoint; + this.wsRemoteEndpoint = wsRemoteEndpoint; + this.wsRemoteEndpoint.setSession(this); + this.wsRemoteEndpoint.setRequest(request); + + request.setWsSession(this); + + this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint); + this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint); + this.webSocketContainer = wsWebSocketContainer; + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout()); + this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize(); + this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize(); + this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout(); + this.requestUri = requestUri; + if (requestParameterMap == null) { + this.requestParameterMap = Collections.emptyMap(); + } else { + this.requestParameterMap = requestParameterMap; + } + this.queryString = queryString; + this.userPrincipal = userPrincipal; + this.httpSessionId = httpSessionId; + this.negotiatedExtensions = negotiatedExtensions; + if (subProtocol == null) { + this.subProtocol = ""; + } else { + this.subProtocol = subProtocol; + } + this.pathParameters = pathParameters; + this.secure = secure; + this.wsRemoteEndpoint.setEncoders(endpointConfig); + this.endpointConfig = endpointConfig; + + this.userProperties.putAll(endpointConfig.getUserProperties()); + this.id = Long.toHexString(ids.getAndIncrement()); + + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + if (instanceManager != null) { + try { + instanceManager.newInstance(localEndpoint); + } catch (Exception e) { + throw new DeploymentException(sm.getString("wsSession.instanceNew"), e); + } + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.created", id)); + } + + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + public static String wsSession_test() { + return sm.getString("wsSession.instanceNew"); + } + + + @Override + public WebSocketContainer getContainer() { + checkState(); + return webSocketContainer; + } + + + @Override + public void addMessageHandler(MessageHandler listener) { + Class<?> target = Util.getMessageType(listener); + doAddMessageHandler(target, listener); + } + + + @Override + public <T> void addMessageHandler(Class<T> clazz, Partial<T> handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @Override + public <T> void addMessageHandler(Class<T> clazz, Whole<T> handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @SuppressWarnings("unchecked") + private void doAddMessageHandler(Class<?> target, MessageHandler listener) { + checkState(); + + // Message handlers that require decoders may map to text messages, + // binary messages, both or neither. + + // The frame processing code expects binary message handlers to + // accept ByteBuffer + + // Use the POJO message handler wrappers as they are designed to wrap + // arbitrary objects with MessageHandlers and can wrap MessageHandlers + // just as easily. + + Set<MessageHandlerResult> mhResults = Util.getMessageHandlers(target, listener, + endpointConfig, this); + + for (MessageHandlerResult mhResult : mhResults) { + switch (mhResult.getType()) { + case TEXT: { + if (textMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText")); + } + textMessageHandler = mhResult.getHandler(); + break; + } + case BINARY: { + if (binaryMessageHandler != null) { + throw new IllegalStateException( + sm.getString("wsSession.duplicateHandlerBinary")); + } + binaryMessageHandler = mhResult.getHandler(); + break; + } + case PONG: { + if (pongMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong")); + } + MessageHandler handler = mhResult.getHandler(); + if (handler instanceof MessageHandler.Whole<?>) { + pongMessageHandler = (MessageHandler.Whole<PongMessage>) handler; + } else { + throw new IllegalStateException( + sm.getString("wsSession.invalidHandlerTypePong")); + } + + break; + } + default: { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType())); + } + } + } + } + + + @Override + public Set<MessageHandler> getMessageHandlers() { + checkState(); + Set<MessageHandler> result = new HashSet<>(); + if (binaryMessageHandler != null) { + result.add(binaryMessageHandler); + } + if (textMessageHandler != null) { + result.add(textMessageHandler); + } + if (pongMessageHandler != null) { + result.add(pongMessageHandler); + } + return result; + } + + + @Override + public void removeMessageHandler(MessageHandler listener) { + checkState(); + if (listener == null) { + return; + } + + MessageHandler wrapped = null; + + if (listener instanceof WrappedMessageHandler) { + wrapped = ((WrappedMessageHandler) listener).getWrappedHandler(); + } + + if (wrapped == null) { + wrapped = listener; + } + + boolean removed = false; + if (wrapped.equals(textMessageHandler) || listener.equals(textMessageHandler)) { + textMessageHandler = null; + removed = true; + } + + if (wrapped.equals(binaryMessageHandler) || listener.equals(binaryMessageHandler)) { + binaryMessageHandler = null; + removed = true; + } + + if (wrapped.equals(pongMessageHandler) || listener.equals(pongMessageHandler)) { + pongMessageHandler = null; + removed = true; + } + + if (!removed) { + // ISE for now. Could swallow this silently / log this if the ISE + // becomes a problem + throw new IllegalStateException( + sm.getString("wsSession.removeHandlerFailed", listener)); + } + } + + + @Override + public String getProtocolVersion() { + checkState(); + return Constants.WS_VERSION_HEADER_VALUE; + } + + + @Override + public String getNegotiatedSubprotocol() { + checkState(); + return subProtocol; + } + + + @Override + public List<Extension> getNegotiatedExtensions() { + checkState(); + return negotiatedExtensions; + } + + + @Override + public boolean isSecure() { + checkState(); + return secure; + } + + + @Override + public boolean isOpen() { + return state == State.OPEN; + } + + + @Override + public long getMaxIdleTimeout() { + checkState(); + return maxIdleTimeout; + } + + + @Override + public void setMaxIdleTimeout(long timeout) { + checkState(); + this.maxIdleTimeout = timeout; + } + + + @Override + public void setMaxBinaryMessageBufferSize(int max) { + checkState(); + this.maxBinaryMessageBufferSize = max; + } + + + @Override + public int getMaxBinaryMessageBufferSize() { + checkState(); + return maxBinaryMessageBufferSize; + } + + + @Override + public void setMaxTextMessageBufferSize(int max) { + checkState(); + this.maxTextMessageBufferSize = max; + } + + + @Override + public int getMaxTextMessageBufferSize() { + checkState(); + return maxTextMessageBufferSize; + } + + + @Override + public Set<Session> getOpenSessions() { + checkState(); + return webSocketContainer.getOpenSessions(localEndpoint); + } + + + @Override + public RemoteEndpoint.Async getAsyncRemote() { + checkState(); + return remoteEndpointAsync; + } + + + @Override + public RemoteEndpoint.Basic getBasicRemote() { + checkState(); + return remoteEndpointBasic; + } + + + @Override + public void close() throws IOException { + close(new CloseReason(CloseCodes.NORMAL_CLOSURE, "")); + } + + + @Override + public void close(CloseReason closeReason) throws IOException { + doClose(closeReason, closeReason); + } + + + /** + * WebSocket 1.0. Section 2.1.5. + * Need internal close method as spec requires that the local endpoint + * receives a 1006 on timeout. + * + * @param closeReasonMessage The close reason to pass to the remote endpoint + * @param closeReasonLocal The close reason to pass to the local endpoint + */ + public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal) { + // Double-checked locking. OK because state is volatile + if (state != State.OPEN) { + return; + } + + synchronized (stateLock) { + if (state != State.OPEN) { + return; + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.doClose", id)); + } + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + + state = State.OUTPUT_CLOSED; + + sendCloseMessage(closeReasonMessage); + fireEndpointOnClose(closeReasonLocal); + } + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + for (FutureToSendHandler f2sh : futures.keySet()) { + f2sh.onResult(sr); + } + } + + + /** + * Called when a close message is received. Should only ever happen once. + * Also called after a protocol error when the ProtocolHandler needs to + * force the closing of the connection. + * + * @param closeReason The reason contained within the received close + * message. + */ + public void onClose(CloseReason closeReason) { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + sendCloseMessage(closeReason); + fireEndpointOnClose(closeReason); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + public void onClose() { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + fireEndpointOnClose(new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, "")); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + private void fireEndpointOnClose(CloseReason closeReason) { + + // Fire the onClose event + Throwable throwable = null; + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onClose(this, closeReason); + } catch (Throwable t1) { + ExceptionUtils.handleThrowable(t1); + throwable = t1; + } finally { + if (instanceManager != null) { + try { + instanceManager.destroyInstance(localEndpoint); + } catch (Throwable t2) { + ExceptionUtils.handleThrowable(t2); + if (throwable == null) { + throwable = t2; + } + } + } + t.setContextClassLoader(cl); + } + + if (throwable != null) { + fireEndpointOnError(throwable); + } + } + + + private void fireEndpointOnError(Throwable throwable) { + + // Fire the onError event + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onError(this, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void sendCloseMessage(CloseReason closeReason) { + // 125 is maximum size for the payload of a control message + ByteBuffer msg = ByteBuffer.allocate(125); + CloseCode closeCode = closeReason.getCloseCode(); + // CLOSED_ABNORMALLY should not be put on the wire + if (closeCode == CloseCodes.CLOSED_ABNORMALLY) { + // PROTOCOL_ERROR is probably better than GOING_AWAY here + msg.putShort((short) CloseCodes.PROTOCOL_ERROR.getCode()); + } else { + msg.putShort((short) closeCode.getCode()); + } + + String reason = closeReason.getReasonPhrase(); + if (reason != null && reason.length() > 0) { + appendCloseReasonWithTruncation(msg, reason); + } + msg.flip(); + try { + wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true); + } catch (IOException | WritePendingException e) { + // Failed to send close message. Close the socket and let the caller + // deal with the Exception + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.sendCloseFail", id), e); + } + wsRemoteEndpoint.close(); + // Failure to send a close message is not unexpected in the case of + // an abnormal closure (usually triggered by a failure to read/write + // from/to the client. In this case do not trigger the endpoint's + // error handling + if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { + localEndpoint.onError(this, e); + } + } finally { + webSocketContainer.unregisterSession(localEndpoint, this); + } + } + + + /** + * Use protected so unit tests can access this method directly. + * @param msg The message + * @param reason The reason + */ + protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) { + // Once the close code has been added there are a maximum of 123 bytes + // left for the reason phrase. If it is truncated then care needs to be + // taken to ensure the bytes are not truncated in the middle of a + // multi-byte UTF-8 character. + byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8); + + if (reasonBytes.length <= 123) { + // No need to truncate + msg.put(reasonBytes); + } else { + // Need to truncate + int remaining = 123 - ELLIPSIS_BYTES_LEN; + int pos = 0; + byte[] bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + while (remaining >= bytesNext.length) { + msg.put(bytesNext); + remaining -= bytesNext.length; + pos++; + bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + } + msg.put(ELLIPSIS_BYTES); + } + } + + + /** + * Make the session aware of a {@link FutureToSendHandler} that will need to + * be forcibly closed if the session closes before the + * {@link FutureToSendHandler} completes. + * @param f2sh The handler + */ + protected void registerFuture(FutureToSendHandler f2sh) { + // Ideally, this code should sync on stateLock so that the correct + // action is taken based on the current state of the connection. + // However, a sync on stateLock can't be used here as it will create the + // possibility of a dead-lock. See BZ 61183. + // Therefore, a slightly less efficient approach is used. + + // Always register the future. + futures.put(f2sh, f2sh); + + if (state == State.OPEN) { + // The session is open. The future has been registered with the open + // session. Normal processing continues. + return; + } + + // The session is closed. The future may or may not have been registered + // in time for it to be processed during session closure. + + if (f2sh.isDone()) { + // The future has completed. It is not known if the future was + // completed normally by the I/O layer or in error by doClose(). It + // doesn't matter which. There is nothing more to do here. + return; + } + + // The session is closed. The Future had not completed when last checked. + // There is a small timing window that means the Future may have been + // completed since the last check. There is also the possibility that + // the Future was not registered in time to be cleaned up during session + // close. + // Attempt to complete the Future with an error result as this ensures + // that the Future completes and any client code waiting on it does not + // hang. It is slightly inefficient since the Future may have been + // completed in another thread or another thread may be about to + // complete the Future but knowing if this is the case requires the sync + // on stateLock (see above). + // Note: If multiple attempts are made to complete the Future, the + // second and subsequent attempts are ignored. + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + f2sh.onResult(sr); + } + + + /** + * Remove a {@link FutureToSendHandler} from the set of tracked instances. + * @param f2sh The handler + */ + protected void unregisterFuture(FutureToSendHandler f2sh) { + futures.remove(f2sh); + } + + + @Override + public URI getRequestURI() { + checkState(); + return requestUri; + } + + + @Override + public Map<String, List<String>> getRequestParameterMap() { + checkState(); + return requestParameterMap; + } + + + @Override + public String getQueryString() { + checkState(); + return queryString; + } + + + @Override + public Principal getUserPrincipal() { + checkState(); + return userPrincipal; + } + + + @Override + public Map<String, String> getPathParameters() { + checkState(); + return pathParameters; + } + + + @Override + public String getId() { + return id; + } + + + @Override + public Map<String, Object> getUserProperties() { + checkState(); + return userProperties; + } + + + public Endpoint getLocal() { + return localEndpoint; + } + + + public String getHttpSessionId() { + return httpSessionId; + } + + private ByteBuffer rawFragments; + + public void processFrame(ByteBuffer buf, byte opCode, boolean last) + throws IOException + { + if (state == State.CLOSED) { + return; + } + + if (opCode == Constants.OPCODE_CONTINUATION) { + opCode = startOpCode; + + if (rawFragments != null && rawFragments.position() > 0) { + rawFragments.put(buf); + rawFragments.flip(); + buf = rawFragments; + } + } else { + if (!last && (opCode == Constants.OPCODE_BINARY || + opCode == Constants.OPCODE_TEXT)) { + startOpCode = opCode; + + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + if (last) { + startOpCode = Constants.OPCODE_CONTINUATION; + } + + if (opCode == Constants.OPCODE_PONG) { + if (pongMessageHandler != null) { + final ByteBuffer b = buf; + + PongMessage pongMessage = new PongMessage() { + @Override + public ByteBuffer getApplicationData() { + return b; + } + }; + + pongMessageHandler.onMessage(pongMessage); + } + } + + if (opCode == Constants.OPCODE_CLOSE) { + CloseReason closeReason; + + if (buf.remaining() >= 2) { + short closeCode = buf.order(ByteOrder.BIG_ENDIAN).getShort(); + + closeReason = new CloseReason( + CloseReason.CloseCodes.getCloseCode(closeCode), + buf.asCharBuffer().toString()); + } else { + closeReason = new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, ""); + } + + onClose(closeReason); + } + + if (opCode == Constants.OPCODE_BINARY) { + onMessage(buf, last); + } + + if (opCode == Constants.OPCODE_TEXT) { + if (messageBufferText.position() == 0 && maxTextMessageBufferSize != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + CoderResult cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (hasTextPartial()) { + do { + onMessage(messageBufferText, false); + + cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + } while (cr.isOverflow()); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + updateRawFragments(buf, last); + + if (hasTextPartial()) { + onMessage(messageBufferText, false); + } + + return; + } + + if (last) { + utf8DecoderMessage.reset(); + } + + updateRawFragments(buf, last); + + onMessage(messageBufferText, last); + } + } + + + private boolean hasTextPartial() { + return textMessageHandler instanceof MessageHandler.Partial<?>; + } + + + private void onMessage(CharBuffer buf, boolean last) throws IOException { + buf.flip(); + try { + onMessage(buf.toString(), last); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + buf.clear(); + } + } + + + private void updateRawFragments(ByteBuffer buf, boolean last) { + if (!last && buf.remaining() > 0) { + if (buf == rawFragments) { + buf.compact(); + } else { + if (rawFragments == null || (rawFragments.position() == 0 && maxTextMessageBufferSize != rawFragments.capacity())) { + rawFragments = ByteBuffer.allocateDirect(maxTextMessageBufferSize); + } + rawFragments.put(buf); + } + } else { + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(String text, boolean last) { + if (hasTextPartial()) { + ((MessageHandler.Partial<String>) textMessageHandler).onMessage(text, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<String>) textMessageHandler).onMessage(text); + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(ByteBuffer buf, boolean last) + throws IOException + { + if (binaryMessageHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<ByteBuffer>) binaryMessageHandler).onMessage(buf, last); + } else { + if (last && (binaryBuffer == null || binaryBuffer.position() == 0)) { + ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(buf); + return; + } + + if (binaryBuffer == null || + (binaryBuffer.position() == 0 && binaryBuffer.capacity() != maxBinaryMessageBufferSize)) + { + binaryBuffer = ByteBuffer.allocateDirect(maxBinaryMessageBufferSize); + } + + if (binaryBuffer.remaining() < buf.remaining()) { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + + binaryBuffer.put(buf); + + if (last) { + binaryBuffer.flip(); + try { + ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(binaryBuffer); + } finally { + binaryBuffer.clear(); + } + } + } + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + getLocal().onError(this, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + protected MessageHandler getTextMessageHandler() { + return textMessageHandler; + } + + + protected MessageHandler getBinaryMessageHandler() { + return binaryMessageHandler; + } + + + protected MessageHandler.Whole<PongMessage> getPongMessageHandler() { + return pongMessageHandler; + } + + + protected void updateLastActive() { + lastActive = System.currentTimeMillis(); + } + + + protected void checkExpiration() { + long timeout = maxIdleTimeout; + if (timeout < 1) { + return; + } + + if (System.currentTimeMillis() - lastActive > timeout) { + String msg = sm.getString("wsSession.timeout", getId()); + if (log.isDebugEnabled()) { + log.debug(msg); + } + doClose(new CloseReason(CloseCodes.GOING_AWAY, msg), + new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg)); + } + } + + + private void checkState() { + if (state == State.CLOSED) { + /* + * As per RFC 6455, a WebSocket connection is considered to be + * closed once a peer has sent and received a WebSocket close frame. + */ + throw new IllegalStateException(sm.getString("wsSession.closed", id)); + } + } + + private enum State { + OPEN, + OUTPUT_CLOSED, + CLOSED + } +} diff --git a/src/java/nginx/unit/websocket/WsWebSocketContainer.java b/src/java/nginx/unit/websocket/WsWebSocketContainer.java new file mode 100644 index 00000000..282665ef --- /dev/null +++ b/src/java/nginx/unit/websocket/WsWebSocketContainer.java @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.ProxySelector; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousChannelGroup; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.TrustManagerFactory; +import javax.websocket.ClientEndpoint; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.buf.StringUtils; +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoEndpointClient; + +public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { + + private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class); + private static final Random RANDOM = new Random(); + private static final byte[] CRLF = new byte[] { 13, 10 }; + + private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] HTTP_VERSION_BYTES = + " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1); + + private volatile AsynchronousChannelGroup asynchronousChannelGroup = null; + private final Object asynchronousChannelGroupLock = new Object(); + + private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static + private final Map<Endpoint, Set<WsSession>> endpointSessionMap = + new HashMap<>(); + private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>(); + private final Object endPointSessionMapLock = new Object(); + + private long defaultAsyncTimeout = -1; + private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long defaultMaxSessionIdleTimeout = 0; + private int backgroundProcessCount = 0; + private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD; + + private InstanceManager instanceManager; + + InstanceManager getInstanceManager() { + return instanceManager; + } + + protected void setInstanceManager(InstanceManager instanceManager) { + this.instanceManager = instanceManager; + } + + @Override + public Session connectToServer(Object pojo, URI path) + throws DeploymentException { + + ClientEndpoint annotation = + pojo.getClass().getAnnotation(ClientEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.missingAnnotation", + pojo.getClass().getName())); + } + + Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders())); + + Class<? extends ClientEndpointConfig.Configurator> configuratorClazz = + annotation.configurator(); + + ClientEndpointConfig.Configurator configurator = null; + if (!ClientEndpointConfig.Configurator.class.equals( + configuratorClazz)) { + try { + configurator = configuratorClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.defaultConfiguratorFail"), e); + } + } + + ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create(); + // Avoid NPE when using RI API JAR - see BZ 56343 + if (configurator != null) { + builder.configurator(configurator); + } + ClientEndpointConfig config = builder. + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + preferredSubprotocols(Arrays.asList(annotation.subprotocols())). + build(); + return connectToServer(ep, config, path); + } + + + @Override + public Session connectToServer(Class<?> annotatedEndpointClass, URI path) + throws DeploymentException { + + Object pojo; + try { + pojo = annotatedEndpointClass.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", + annotatedEndpointClass.getName()), e); + } + + return connectToServer(pojo, path); + } + + + @Override + public Session connectToServer(Class<? extends Endpoint> clazz, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + + Endpoint endpoint; + try { + endpoint = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", clazz.getName()), + e); + } + + return connectToServer(endpoint, clientEndpointConfiguration, path); + } + + + @Override + public Session connectToServer(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>()); + } + + private Session connectToServerRecursive(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path, + Set<URI> redirectSet) + throws DeploymentException { + + boolean secure = false; + ByteBuffer proxyConnect = null; + URI proxyPath; + + // Validate scheme (and build proxyPath) + String scheme = path.getScheme(); + if ("ws".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("http" + path.toString().substring(2)); + } else if ("wss".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("https" + path.toString().substring(3)); + secure = true; + } else { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.pathWrongScheme", scheme)); + } + + // Validate host + String host = path.getHost(); + if (host == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.pathNoHost")); + } + int port = path.getPort(); + + SocketAddress sa = null; + + // Check to see if a proxy is configured. Javadoc indicates return value + // will never be null + List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath); + Proxy selectedProxy = null; + for (Proxy proxy : proxies) { + if (proxy.type().equals(Proxy.Type.HTTP)) { + sa = proxy.address(); + if (sa instanceof InetSocketAddress) { + InetSocketAddress inet = (InetSocketAddress) sa; + if (inet.isUnresolved()) { + sa = new InetSocketAddress(inet.getHostName(), inet.getPort()); + } + } + selectedProxy = proxy; + break; + } + } + + // If the port is not explicitly specified, compute it based on the + // scheme + if (port == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else { + // Must be wss due to scheme validation above + port = 443; + } + } + + // If sa is null, no proxy is configured so need to create sa + if (sa == null) { + sa = new InetSocketAddress(host, port); + } else { + proxyConnect = createProxyRequest(host, port); + } + + // Create the initial HTTP request to open the WebSocket connection + Map<String, List<String>> reqHeaders = createRequestHeaders(host, port, + clientEndpointConfiguration); + clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders); + if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null + && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) { + List<String> originValues = new ArrayList<>(1); + originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE); + reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues); + } + ByteBuffer request = createRequest(path, reqHeaders); + + AsynchronousSocketChannel socketChannel; + try { + socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup()); + } catch (IOException ioe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe); + } + + Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties(); + + // Get the connection timeout + long timeout = Constants.IO_TIMEOUT_MS_DEFAULT; + String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY); + if (timeoutValue != null) { + timeout = Long.valueOf(timeoutValue).intValue(); + } + + // Set-up + // Same size as the WsFrame input buffer + ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize()); + String subProtocol; + boolean success = false; + List<Extension> extensionsAgreed = new ArrayList<>(); + Transformation transformation = null; + + // Open the connection + Future<Void> fConnect = socketChannel.connect(sa); + AsyncChannelWrapper channel = null; + + if (proxyConnect != null) { + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + // Proxy CONNECT is clear text + channel = new AsyncChannelWrapperNonSecure(socketChannel); + writeRequest(channel, proxyConnect, timeout); + HttpResponse httpResponse = processResponse(response, channel, timeout); + if (httpResponse.getStatus() != 200) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.proxyConnectFail", selectedProxy, + Integer.toString(httpResponse.getStatus()))); + } + } catch (TimeoutException | InterruptedException | ExecutionException | + EOFException e) { + if (channel != null) { + channel.close(); + } + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } + } + + if (secure) { + // Regardless of whether a non-secure wrapper was created for a + // proxy CONNECT, need to use TLS from this point on so wrap the + // original AsynchronousSocketChannel + SSLEngine sslEngine = createSSLEngine(userProperties, host, port); + channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); + } else if (channel == null) { + // Only need to wrap as this point if it wasn't wrapped to process a + // proxy CONNECT + channel = new AsyncChannelWrapperNonSecure(socketChannel); + } + + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + + Future<Void> fHandshake = channel.handshake(); + fHandshake.get(timeout, TimeUnit.MILLISECONDS); + + writeRequest(channel, request, timeout); + + HttpResponse httpResponse = processResponse(response, channel, timeout); + + // Check maximum permitted redirects + int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT; + String maxRedirectsValue = + (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY); + if (maxRedirectsValue != null) { + maxRedirects = Integer.parseInt(maxRedirectsValue); + } + + if (httpResponse.status != 101) { + if(isRedirectStatus(httpResponse.status)){ + List<String> locationHeader = + httpResponse.getHandshakeResponse().getHeaders().get( + Constants.LOCATION_HEADER_NAME); + + if (locationHeader == null || locationHeader.isEmpty() || + locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingLocationHeader", + Integer.toString(httpResponse.status))); + } + + URI redirectLocation = URI.create(locationHeader.get(0)).normalize(); + + if (!redirectLocation.isAbsolute()) { + redirectLocation = path.resolve(redirectLocation); + } + + String redirectScheme = redirectLocation.getScheme().toLowerCase(); + + if (redirectScheme.startsWith("http")) { + redirectLocation = new URI(redirectScheme.replace("http", "ws"), + redirectLocation.getUserInfo(), redirectLocation.getHost(), + redirectLocation.getPort(), redirectLocation.getPath(), + redirectLocation.getQuery(), redirectLocation.getFragment()); + } + + if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.redirectThreshold", redirectLocation, + Integer.toString(redirectSet.size()), + Integer.toString(maxRedirects))); + } + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet); + + } + + else if (httpResponse.status == 401) { + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.failedAuthentication", + Integer.valueOf(httpResponse.status))); + } + + List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse() + .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME); + + if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() || + wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingWWWAuthenticateHeader", + Integer.toString(httpResponse.status))); + } + + String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0]; + String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1) + .split("\\s", 3)[1]; + + Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme); + + if (auth == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.unsupportedAuthScheme", + Integer.valueOf(httpResponse.status), authScheme)); + } + + userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization( + requestUri, wwwAuthenticateHeaders.get(0), userProperties)); + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet); + + } + + else { + throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", + Integer.toString(httpResponse.status))); + } + } + HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse(); + clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse); + + // Sub-protocol + List<String> protocolHeaders = handshakeResponse.getHeaders().get( + Constants.WS_PROTOCOL_HEADER_NAME); + if (protocolHeaders == null || protocolHeaders.size() == 0) { + subProtocol = null; + } else if (protocolHeaders.size() == 1) { + subProtocol = protocolHeaders.get(0); + } else { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.invalidSubProtocol")); + } + + // Extensions + // Should normally only be one header but handle the case of + // multiple headers + List<String> extHeaders = handshakeResponse.getHeaders().get( + Constants.WS_EXTENSIONS_HEADER_NAME); + if (extHeaders != null) { + for (String extHeader : extHeaders) { + Util.parseExtensionHeader(extensionsAgreed, extHeader); + } + } + + // Build the transformations + TransformationFactory factory = TransformationFactory.getInstance(); + for (Extension extension : extensionsAgreed) { + List<List<Extension.Parameter>> wrapper = new ArrayList<>(1); + wrapper.add(extension.getParameters()); + Transformation t = factory.create(extension.getName(), wrapper, false); + if (t == null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidExtensionParameters")); + } + if (transformation == null) { + transformation = t; + } else { + transformation.setNext(t); + } + } + + success = true; + } catch (ExecutionException | InterruptedException | SSLException | + EOFException | TimeoutException | URISyntaxException | AuthenticationException e) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } finally { + if (!success) { + channel.close(); + } + } + + // Switch to WebSocket + WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel); + + WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient, + this, null, null, null, null, null, extensionsAgreed, + subProtocol, Collections.<String,String>emptyMap(), secure, + clientEndpointConfiguration, null); + + WsFrameClient wsFrameClient = new WsFrameClient(response, channel, + wsSession, transformation); + // WsFrame adds the necessary final transformations. Copy the + // completed transformation chain to the remote end point. + wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation()); + + endpoint.onOpen(wsSession, clientEndpointConfiguration); + registerSession(endpoint, wsSession); + + /* It is possible that the server sent one or more messages as soon as + * the WebSocket connection was established. Depending on the exact + * timing of when those messages were sent they could be sat in the + * input buffer waiting to be read and will not trigger a "data + * available to read" event. Therefore, it is necessary to process the + * input buffer here. Note that this happens on the current thread which + * means that this thread will be used for any onMessage notifications. + * This is a special case. Subsequent "data available to read" events + * will be handled by threads from the AsyncChannelGroup's executor. + */ + wsFrameClient.startInputProcessing(); + + return wsSession; + } + + + private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, + long timeout) throws TimeoutException, InterruptedException, ExecutionException { + int toWrite = request.limit(); + + Future<Integer> fWrite = channel.write(request); + Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + + while (toWrite > 0) { + fWrite = channel.write(request); + thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + } + } + + + private static boolean isRedirectStatus(int httpResponseCode) { + + boolean isRedirect = false; + + switch (httpResponseCode) { + case Constants.MULTIPLE_CHOICES: + case Constants.MOVED_PERMANENTLY: + case Constants.FOUND: + case Constants.SEE_OTHER: + case Constants.USE_PROXY: + case Constants.TEMPORARY_REDIRECT: + isRedirect = true; + break; + default: + break; + } + + return isRedirect; + } + + + private static ByteBuffer createProxyRequest(String host, int port) { + StringBuilder request = new StringBuilder(); + request.append("CONNECT "); + request.append(host); + request.append(':'); + request.append(port); + + request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: "); + request.append(host); + request.append(':'); + request.append(port); + + request.append("\r\n\r\n"); + + byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1); + return ByteBuffer.wrap(bytes); + } + + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + + if (!wsSession.isOpen()) { + // The session was closed during onOpen. No need to register it. + return; + } + synchronized (endPointSessionMapLock) { + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().register(this); + } + Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions == null) { + wsSessions = new HashSet<>(); + endpointSessionMap.put(endpoint, wsSessions); + } + wsSessions.add(wsSession); + } + sessions.put(wsSession, wsSession); + } + + + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + + synchronized (endPointSessionMapLock) { + Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions != null) { + wsSessions.remove(wsSession); + if (wsSessions.size() == 0) { + endpointSessionMap.remove(endpoint); + } + } + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + sessions.remove(wsSession); + } + + + Set<Session> getOpenSessions(Endpoint endpoint) { + HashSet<Session> result = new HashSet<>(); + synchronized (endPointSessionMapLock) { + Set<WsSession> sessions = endpointSessionMap.get(endpoint); + if (sessions != null) { + result.addAll(sessions); + } + } + return result; + } + + private static Map<String, List<String>> createRequestHeaders(String host, int port, + ClientEndpointConfig clientEndpointConfiguration) { + + Map<String, List<String>> headers = new HashMap<>(); + List<Extension> extensions = clientEndpointConfiguration.getExtensions(); + List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols(); + Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties(); + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + List<String> authValues = new ArrayList<>(1); + authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME)); + headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues); + } + + // Host header + List<String> hostValues = new ArrayList<>(1); + if (port == -1) { + hostValues.add(host); + } else { + hostValues.add(host + ':' + port); + } + + headers.put(Constants.HOST_HEADER_NAME, hostValues); + + // Upgrade header + List<String> upgradeValues = new ArrayList<>(1); + upgradeValues.add(Constants.UPGRADE_HEADER_VALUE); + headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues); + + // Connection header + List<String> connectionValues = new ArrayList<>(1); + connectionValues.add(Constants.CONNECTION_HEADER_VALUE); + headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues); + + // WebSocket version header + List<String> wsVersionValues = new ArrayList<>(1); + wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE); + headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues); + + // WebSocket key + List<String> wsKeyValues = new ArrayList<>(1); + wsKeyValues.add(generateWsKeyValue()); + headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues); + + // WebSocket sub-protocols + if (subProtocols != null && subProtocols.size() > 0) { + headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols); + } + + // WebSocket extensions + if (extensions != null && extensions.size() > 0) { + headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, + generateExtensionHeaders(extensions)); + } + + return headers; + } + + + private static List<String> generateExtensionHeaders(List<Extension> extensions) { + List<String> result = new ArrayList<>(extensions.size()); + for (Extension extension : extensions) { + StringBuilder header = new StringBuilder(); + header.append(extension.getName()); + for (Extension.Parameter param : extension.getParameters()) { + header.append(';'); + header.append(param.getName()); + String value = param.getValue(); + if (value != null && value.length() > 0) { + header.append('='); + header.append(value); + } + } + result.add(header.toString()); + } + return result; + } + + + private static String generateWsKeyValue() { + byte[] keyBytes = new byte[16]; + RANDOM.nextBytes(keyBytes); + return Base64.encodeBase64String(keyBytes); + } + + + private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) { + ByteBuffer result = ByteBuffer.allocate(4 * 1024); + + // Request line + result.put(GET_BYTES); + if (null == uri.getPath() || "".equals(uri.getPath())) { + result.put(ROOT_URI_BYTES); + } else { + result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1)); + } + String query = uri.getRawQuery(); + if (query != null) { + result.put((byte) '?'); + result.put(query.getBytes(StandardCharsets.ISO_8859_1)); + } + result.put(HTTP_VERSION_BYTES); + + // Headers + for (Entry<String, List<String>> entry : reqHeaders.entrySet()) { + result = addHeader(result, entry.getKey(), entry.getValue()); + } + + // Terminating CRLF + result.put(CRLF); + + result.flip(); + + return result; + } + + + private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) { + if (values.isEmpty()) { + return result; + } + + result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, CRLF); + + return result; + } + + + private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) { + if (bytes.length > input.remaining()) { + int newSize; + if (bytes.length > input.capacity()) { + newSize = 2 * bytes.length; + } else { + newSize = input.capacity() * 2; + } + ByteBuffer expanded = ByteBuffer.allocate(newSize); + input.flip(); + expanded.put(input); + input = expanded; + } + return input.put(bytes); + } + + + /** + * Process response, blocking until HTTP response has been fully received. + * @throws ExecutionException + * @throws InterruptedException + * @throws DeploymentException + * @throws TimeoutException + */ + private HttpResponse processResponse(ByteBuffer response, + AsyncChannelWrapper channel, long timeout) throws InterruptedException, + ExecutionException, DeploymentException, EOFException, + TimeoutException { + + Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); + + int status = 0; + boolean readStatus = false; + boolean readHeaders = false; + String line = null; + while (!readHeaders) { + // On entering loop buffer will be empty and at the start of a new + // loop the buffer will have been fully read. + response.clear(); + // Blocking read + Future<Integer> read = channel.read(response); + Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS); + if (bytesRead.intValue() == -1) { + throw new EOFException(); + } + response.flip(); + while (response.hasRemaining() && !readHeaders) { + if (line == null) { + line = readLine(response); + } else { + line += readLine(response); + } + if ("\r\n".equals(line)) { + readHeaders = true; + } else if (line.endsWith("\r\n")) { + if (readStatus) { + parseHeaders(line, headers); + } else { + status = parseStatus(line); + readStatus = true; + } + line = null; + } + } + } + + return new HttpResponse(status, new WsHandshakeResponse(headers)); + } + + + private int parseStatus(String line) throws DeploymentException { + // This client only understands HTTP 1. + // RFC2616 is case specific + String[] parts = line.trim().split(" "); + // CONNECT for proxy may return a 1.0 response + if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + try { + return Integer.parseInt(parts[1]); + } catch (NumberFormatException nfe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + } + + + private void parseHeaders(String line, Map<String,List<String>> headers) { + // Treat headers as single values by default. + + int index = line.indexOf(':'); + if (index == -1) { + log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line)); + return; + } + // Header names are case insensitive so always use lower case + String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH); + // Multi-value headers are stored as a single header and the client is + // expected to handle splitting into individual values + String headerValue = line.substring(index + 1).trim(); + + List<String> values = headers.get(headerName); + if (values == null) { + values = new ArrayList<>(1); + headers.put(headerName, values); + } + values.add(headerValue); + } + + private String readLine(ByteBuffer response) { + // All ISO-8859-1 + StringBuilder sb = new StringBuilder(); + + char c = 0; + while (response.hasRemaining()) { + c = (char) response.get(); + sb.append(c); + if (c == 10) { + break; + } + } + + return sb.toString(); + } + + + private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port) + throws DeploymentException { + + try { + // See if a custom SSLContext has been provided + SSLContext sslContext = + (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY); + + if (sslContext == null) { + // Create the SSL Context + sslContext = SSLContext.getInstance("TLS"); + + // Trust store + String sslTrustStoreValue = + (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY); + if (sslTrustStoreValue != null) { + String sslTrustStorePwdValue = (String) userProperties.get( + Constants.SSL_TRUSTSTORE_PWD_PROPERTY); + if (sslTrustStorePwdValue == null) { + sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT; + } + + File keyStoreFile = new File(sslTrustStoreValue); + KeyStore ks = KeyStore.getInstance("JKS"); + try (InputStream is = new FileInputStream(keyStoreFile)) { + ks.load(is, sslTrustStorePwdValue.toCharArray()); + } + + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + + sslContext.init(null, tmf.getTrustManagers(), null); + } else { + sslContext.init(null, null, null); + } + } + + SSLEngine engine = sslContext.createSSLEngine(host, port); + + String sslProtocolsValue = + (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY); + if (sslProtocolsValue != null) { + engine.setEnabledProtocols(sslProtocolsValue.split(",")); + } + + engine.setUseClientMode(true); + + // Enable host verification + // Start with current settings (returns a copy) + SSLParameters sslParams = engine.getSSLParameters(); + // Use HTTPS since WebSocket starts over HTTP(S) + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + // Write the parameters back + engine.setSSLParameters(sslParams); + + return engine; + } catch (Exception e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.sslEngineFail"), e); + } + } + + + @Override + public long getDefaultMaxSessionIdleTimeout() { + return defaultMaxSessionIdleTimeout; + } + + + @Override + public void setDefaultMaxSessionIdleTimeout(long timeout) { + this.defaultMaxSessionIdleTimeout = timeout; + } + + + @Override + public int getDefaultMaxBinaryMessageBufferSize() { + return maxBinaryMessageBufferSize; + } + + + @Override + public void setDefaultMaxBinaryMessageBufferSize(int max) { + maxBinaryMessageBufferSize = max; + } + + + @Override + public int getDefaultMaxTextMessageBufferSize() { + return maxTextMessageBufferSize; + } + + + @Override + public void setDefaultMaxTextMessageBufferSize(int max) { + maxTextMessageBufferSize = max; + } + + + /** + * {@inheritDoc} + * + * Currently, this implementation does not support any extensions. + */ + @Override + public Set<Extension> getInstalledExtensions() { + return Collections.emptySet(); + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public long getDefaultAsyncSendTimeout() { + return defaultAsyncTimeout; + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public void setAsyncSendTimeout(long timeout) { + this.defaultAsyncTimeout = timeout; + } + + + /** + * Cleans up the resources still in use by WebSocket sessions created from + * this container. This includes closing sessions and cancelling + * {@link Future}s associated with blocking read/writes. + */ + public void destroy() { + CloseReason cr = new CloseReason( + CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown")); + + for (WsSession session : sessions.keySet()) { + try { + session.close(cr); + } catch (IOException ioe) { + log.debug(sm.getString( + "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe); + } + } + + // Only unregister with AsyncChannelGroupUtil if this instance + // registered with it + if (asynchronousChannelGroup != null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup != null) { + AsyncChannelGroupUtil.unregister(); + asynchronousChannelGroup = null; + } + } + } + } + + + private AsynchronousChannelGroup getAsynchronousChannelGroup() { + // Use AsyncChannelGroupUtil to share a common group amongst all + // WebSocket clients + AsynchronousChannelGroup result = asynchronousChannelGroup; + if (result == null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup == null) { + asynchronousChannelGroup = AsyncChannelGroupUtil.register(); + } + result = asynchronousChannelGroup; + } + } + return result; + } + + + // ----------------------------------------------- BackgroundProcess methods + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + for (WsSession wsSession : sessions.keySet()) { + wsSession.checkExpiration(); + } + } + + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 10 which means session expirations are processed + * every 10 seconds. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + private static class HttpResponse { + private final int status; + private final HandshakeResponse handshakeResponse; + + public HttpResponse(int status, HandshakeResponse handshakeResponse) { + this.status = status; + this.handshakeResponse = handshakeResponse; + } + + + public int getStatus() { + return status; + } + + + public HandshakeResponse getHandshakeResponse() { + return handshakeResponse; + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/Constants.java b/src/java/nginx/unit/websocket/pojo/Constants.java new file mode 100644 index 00000000..93cdecc7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/Constants.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String POJO_PATH_PARAM_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.pathParams"; + public static final String POJO_METHOD_MAPPING_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.methodMapping"; + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/pojo/LocalStrings.properties b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties new file mode 100644 index 00000000..00ab7e6b --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pojoEndpointBase.closeSessionFail=Failed to close WebSocket session during error handling +pojoEndpointBase.onCloseFail=Failed to call onClose method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onError=No error handling configured for [{0}] and the following error occurred +pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for POJO of type [{0}] +pojoEndpointServer.getPojoInstanceFail=Failed to create instance of POJO of type [{0}] +pojoMethodMapping.decodePathParamFail=Failed to decode path parameter value [{0}] to expected type [{1}] +pojoMethodMapping.duplicateAnnotation=Duplicate annotations [{0}] present on class [{1}] +pojoMethodMapping.duplicateLastParam=Multiple boolean (last) parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateMessageParam=Multiple message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicatePongMessageParam=Multiple PongMessage parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateSessionParam=Multiple session parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.invalidDecoder=The specified decoder of type [{0}] could not be instantiated +pojoMethodMapping.invalidPathParamType=Parameters annotated with @PathParam may only be Strings, Java primitives or a boxed version thereof +pojoMethodMapping.methodNotPublic=The annotated method [{0}] is not public +pojoMethodMapping.noPayload=No payload parameter present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.onErrorNoThrowable=No Throwable parameter was present on the method [{0}] of class [{1}] that was annotated with OnError +pojoMethodMapping.paramWithoutAnnotation=A parameter of type [{0}] was found on method[{1}] of class [{2}] that did not have a @PathParam annotation +pojoMethodMapping.partialInputStream=Invalid InputStream and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialObject=Invalid Object and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialPong=Invalid PongMessage and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialReader=Invalid Reader and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.pongWithPayload=Invalid PongMessage and Message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message +pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for this implementation is Integer.MAX_VALUE diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java new file mode 100644 index 00000000..be679a35 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Base implementation (client and server have different concrete + * implementations) of the wrapper that converts a POJO instance into a + * WebSocket endpoint instance. + */ +public abstract class PojoEndpointBase extends Endpoint { + + private final Log log = LogFactory.getLog(PojoEndpointBase.class); // must not be static + private static final StringManager sm = StringManager.getManager(PojoEndpointBase.class); + + private Object pojo; + private Map<String,String> pathParameters; + private PojoMethodMapping methodMapping; + + + protected final void doOnOpen(Session session, EndpointConfig config) { + PojoMethodMapping methodMapping = getMethodMapping(); + Object pojo = getPojo(); + Map<String,String> pathParameters = getPathParameters(); + + // Add message handlers before calling onOpen since that may trigger a + // message which in turn could trigger a response and/or close the + // session + for (MessageHandler mh : methodMapping.getMessageHandlers(pojo, + pathParameters, session, config)) { + session.addMessageHandler(mh); + } + + if (methodMapping.getOnOpen() != null) { + try { + methodMapping.getOnOpen().invoke(pojo, + methodMapping.getOnOpenArgs( + pathParameters, session, config)); + + } catch (IllegalAccessException e) { + // Reflection related problems + log.error(sm.getString( + "pojoEndpointBase.onOpenFail", + pojo.getClass().getName()), e); + handleOnOpenOrCloseError(session, e); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + handleOnOpenOrCloseError(session, cause); + } catch (Throwable t) { + handleOnOpenOrCloseError(session, t); + } + } + } + + + private void handleOnOpenOrCloseError(Session session, Throwable t) { + // If really fatal - re-throw + ExceptionUtils.handleThrowable(t); + + // Trigger the error handler and close the session + onError(session, t); + try { + session.close(); + } catch (IOException ioe) { + log.warn(sm.getString("pojoEndpointBase.closeSessionFail"), ioe); + } + } + + @Override + public final void onClose(Session session, CloseReason closeReason) { + + if (methodMapping.getOnClose() != null) { + try { + methodMapping.getOnClose().invoke(pojo, + methodMapping.getOnCloseArgs(pathParameters, session, closeReason)); + } catch (Throwable t) { + log.error(sm.getString("pojoEndpointBase.onCloseFail", + pojo.getClass().getName()), t); + handleOnOpenOrCloseError(session, t); + } + } + + // Trigger the destroy method for any associated decoders + Set<MessageHandler> messageHandlers = session.getMessageHandlers(); + for (MessageHandler messageHandler : messageHandlers) { + if (messageHandler instanceof PojoMessageHandlerWholeBase<?>) { + ((PojoMessageHandlerWholeBase<?>) messageHandler).onClose(); + } + } + } + + + @Override + public final void onError(Session session, Throwable throwable) { + + if (methodMapping.getOnError() == null) { + log.error(sm.getString("pojoEndpointBase.onError", + pojo.getClass().getName()), throwable); + } else { + try { + methodMapping.getOnError().invoke( + pojo, + methodMapping.getOnErrorArgs(pathParameters, session, + throwable)); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString("pojoEndpointBase.onErrorFail", + pojo.getClass().getName()), t); + } + } + } + + protected Object getPojo() { return pojo; } + protected void setPojo(Object pojo) { this.pojo = pojo; } + + + protected Map<String,String> getPathParameters() { return pathParameters; } + protected void setPathParameters(Map<String,String> pathParameters) { + this.pathParameters = pathParameters; + } + + + protected PojoMethodMapping getMethodMapping() { return methodMapping; } + protected void setMethodMapping(PojoMethodMapping methodMapping) { + this.methodMapping = methodMapping; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java new file mode 100644 index 00000000..6e569487 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Collections; +import java.util.List; + +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.ClientEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointClient extends PojoEndpointBase { + + public PojoEndpointClient(Object pojo, + List<Class<? extends Decoder>> decoders) throws DeploymentException { + setPojo(pojo); + setMethodMapping( + new PojoMethodMapping(pojo.getClass(), decoders, null)); + setPathParameters(Collections.<String,String>emptyMap()); + } + + @Override + public void onOpen(Session session, EndpointConfig config) { + doOnOpen(session, config); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java new file mode 100644 index 00000000..499f8274 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Map; + +import javax.websocket.EndpointConfig; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpointConfig; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.server.ServerEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointServer extends PojoEndpointBase { + + private static final StringManager sm = + StringManager.getManager(PojoEndpointServer.class); + + @Override + public void onOpen(Session session, EndpointConfig endpointConfig) { + + ServerEndpointConfig sec = (ServerEndpointConfig) endpointConfig; + + Object pojo; + try { + pojo = sec.getConfigurator().getEndpointInstance( + sec.getEndpointClass()); + } catch (InstantiationException e) { + throw new IllegalArgumentException(sm.getString( + "pojoEndpointServer.getPojoInstanceFail", + sec.getEndpointClass().getName()), e); + } + setPojo(pojo); + + @SuppressWarnings("unchecked") + Map<String,String> pathParameters = + (Map<String, String>) sec.getUserProperties().get( + Constants.POJO_PATH_PARAM_KEY); + setPathParameters(pathParameters); + + PojoMethodMapping methodMapping = + (PojoMethodMapping) sec.getUserProperties().get( + Constants.POJO_METHOD_MAPPING_KEY); + setMethodMapping(methodMapping); + + doOnOpen(session, endpointConfig); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java new file mode 100644 index 00000000..b72d719a --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.MessageHandler; +import javax.websocket.RemoteEndpoint; +import javax.websocket.Session; + +import org.apache.tomcat.util.ExceptionUtils; +import nginx.unit.websocket.WrappedMessageHandler; + +/** + * Common implementation code for the POJO message handlers. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerBase<T> + implements WrappedMessageHandler { + + protected final Object pojo; + protected final Method method; + protected final Session session; + protected final Object[] params; + protected final int indexPayload; + protected final boolean convert; + protected final int indexSession; + protected final long maxMessageSize; + + public PojoMessageHandlerBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession, long maxMessageSize) { + this.pojo = pojo; + this.method = method; + // TODO: The method should already be accessible here but the following + // code seems to be necessary in some as yet not fully understood cases. + try { + this.method.setAccessible(true); + } catch (Exception e) { + // It is better to make sure the method is accessible, but + // ignore exceptions and hope for the best + } + this.session = session; + this.params = params; + this.indexPayload = indexPayload; + this.convert = convert; + this.indexSession = indexSession; + this.maxMessageSize = maxMessageSize; + } + + + protected final void processResult(Object result) { + if (result == null) { + return; + } + + RemoteEndpoint.Basic remoteEndpoint = session.getBasicRemote(); + try { + if (result instanceof String) { + remoteEndpoint.sendText((String) result); + } else if (result instanceof ByteBuffer) { + remoteEndpoint.sendBinary((ByteBuffer) result); + } else if (result instanceof byte[]) { + remoteEndpoint.sendBinary(ByteBuffer.wrap((byte[]) result)); + } else { + remoteEndpoint.sendObject(result); + } + } catch (IOException | EncodeException ioe) { + throw new IllegalStateException(ioe); + } + } + + + /** + * Expose the POJO if it is a message handler so the Session is able to + * match requests to remove handlers if the original handler has been + * wrapped. + */ + @Override + public final MessageHandler getWrappedHandler() { + if (pojo instanceof MessageHandler) { + return (MessageHandler) pojo; + } else { + return null; + } + } + + + @Override + public final long getMaxMessageSize() { + return maxMessageSize; + } + + + protected final void handlePojoMethodException(Throwable t) { + t = ExceptionUtils.unwrapInvocationTargetException(t); + ExceptionUtils.handleThrowable(t); + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t.getMessage(), t); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java new file mode 100644 index 00000000..d6f37724 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO partial message handlers. All + * the real work is done in this class and in the superclass. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerPartialBase<T> + extends PojoMessageHandlerBase<T> implements MessageHandler.Partial<T> { + + private final int indexBoolean; + + public PojoMessageHandlerPartialBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexBoolean, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + this.indexBoolean = indexBoolean; + } + + + @Override + public final void onMessage(T message, boolean last) { + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + Object[] parameters = params.clone(); + if (indexBoolean != -1) { + parameters[indexBoolean] = Boolean.valueOf(last); + } + if (indexSession != -1) { + parameters[indexSession] = session; + } + if (convert) { + parameters[indexPayload] = ((ByteBuffer) message).array(); + } else { + parameters[indexPayload] = message; + } + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java new file mode 100644 index 00000000..1d334017 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.Session; + +/** + * ByteBuffer specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialBinary + extends PojoMessageHandlerPartialBase<ByteBuffer> { + + public PojoMessageHandlerPartialBinary(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java new file mode 100644 index 00000000..8f7c1a0d --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.Session; + +/** + * Text specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialText + extends PojoMessageHandlerPartialBase<String> { + + public PojoMessageHandlerPartialText(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java new file mode 100644 index 00000000..23333eb7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO whole message handlers. All the real + * work is done in this class and in the superclass. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerWholeBase<T> + extends PojoMessageHandlerBase<T> implements MessageHandler.Whole<T> { + + public PojoMessageHandlerWholeBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + } + + + @Override + public final void onMessage(T message) { + + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + + // Can this message be decoded? + Object payload; + try { + payload = decode(message); + } catch (DecodeException de) { + ((WsSession) session).getLocal().onError(session, de); + return; + } + + if (payload == null) { + // Not decoded. Convert if required. + if (convert) { + payload = convert(message); + } else { + payload = message; + } + } + + Object[] parameters = params.clone(); + if (indexSession != -1) { + parameters[indexSession] = session; + } + parameters[indexPayload] = payload; + + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } + + protected Object convert(T message) { + return message; + } + + + protected abstract Object decode(T message) throws DecodeException; + protected abstract void onClose(); +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java new file mode 100644 index 00000000..07ff0648 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; + +/** + * ByteBuffer specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeBinary + extends PojoMessageHandlerWholeBase<ByteBuffer> { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeBinary.class); + + private final List<Decoder> decoders = new ArrayList<>(); + + private final boolean isForInputStream; + + public PojoMessageHandlerWholeBinary(Object pojo, Method method, + Session session, EndpointConfig config, + List<Class<? extends Decoder>> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + boolean isForInputStream, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update binary text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxBinaryMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxBinaryMessageBufferSize((int) maxMessageSize); + } + + try { + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + if (Binary.class.isAssignableFrom(decoderClazz)) { + Binary<?> decoder = (Binary<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (BinaryStream.class.isAssignableFrom( + decoderClazz)) { + BinaryStream<?> decoder = (BinaryStream<?>) + decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Text decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + this.isForInputStream = isForInputStream; + } + + + @Override + protected Object decode(ByteBuffer message) throws DecodeException { + for (Decoder decoder : decoders) { + if (decoder instanceof Binary) { + if (((Binary<?>) decoder).willDecode(message)) { + return ((Binary<?>) decoder).decode(message); + } + } else { + byte[] array = new byte[message.limit() - message.position()]; + message.get(array); + ByteArrayInputStream bais = new ByteArrayInputStream(array); + try { + return ((BinaryStream<?>) decoder).decode(bais); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(ByteBuffer message) { + byte[] array = new byte[message.remaining()]; + message.get(array); + if (isForInputStream) { + return new ByteArrayInputStream(array); + } else { + return array; + } + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java new file mode 100644 index 00000000..bdedd7de --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.PongMessage; +import javax.websocket.Session; + +/** + * PongMessage specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholePong + extends PojoMessageHandlerWholeBase<PongMessage> { + + public PojoMessageHandlerWholePong(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, -1); + } + + @Override + protected Object decode(PongMessage message) { + // Never decoded + return null; + } + + + @Override + protected void onClose() { + // NO-OP + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java new file mode 100644 index 00000000..59007349 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.io.StringReader; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Util; + + +/** + * Text specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeText + extends PojoMessageHandlerWholeBase<String> { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeText.class); + + private final List<Decoder> decoders = new ArrayList<>(); + private final Class<?> primitiveType; + + public PojoMessageHandlerWholeText(Object pojo, Method method, + Session session, EndpointConfig config, + List<Class<? extends Decoder>> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update max text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxTextMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxTextMessageBufferSize((int) maxMessageSize); + } + + // Check for primitives + Class<?> type = method.getParameterTypes()[indexPayload]; + if (Util.isPrimitive(type)) { + primitiveType = type; + return; + } else { + primitiveType = null; + } + + try { + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + if (Text.class.isAssignableFrom(decoderClazz)) { + Text<?> decoder = (Text<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (TextStream.class.isAssignableFrom( + decoderClazz)) { + TextStream<?> decoder = + (TextStream<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Binary decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + } + + + @Override + protected Object decode(String message) throws DecodeException { + // Handle primitives + if (primitiveType != null) { + return Util.coerceToType(primitiveType, message); + } + // Handle full decoders + for (Decoder decoder : decoders) { + if (decoder instanceof Text) { + if (((Text<?>) decoder).willDecode(message)) { + return ((Text<?>) decoder).decode(message); + } + } else { + StringReader r = new StringReader(message); + try { + return ((TextStream<?>) decoder).decode(r); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(String message) { + return new StringReader(message); + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java new file mode 100644 index 00000000..2385b5c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.PongMessage; +import javax.websocket.Session; +import javax.websocket.server.PathParam; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.DecoderEntry; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.Util.DecoderMatch; + +/** + * For a POJO class annotated with + * {@link javax.websocket.server.ServerEndpoint}, an instance of this class + * creates and caches the method handler, method information and parameter + * information for the onXXX calls. + */ +public class PojoMethodMapping { + + private static final StringManager sm = + StringManager.getManager(PojoMethodMapping.class); + + private final Method onOpen; + private final Method onClose; + private final Method onError; + private final PojoPathParam[] onOpenParams; + private final PojoPathParam[] onCloseParams; + private final PojoPathParam[] onErrorParams; + private final List<MessageHandlerInfo> onMessage = new ArrayList<>(); + private final String wsPath; + + + public PojoMethodMapping(Class<?> clazzPojo, + List<Class<? extends Decoder>> decoderClazzes, String wsPath) + throws DeploymentException { + + this.wsPath = wsPath; + + List<DecoderEntry> decoders = Util.getDecoders(decoderClazzes); + Method open = null; + Method close = null; + Method error = null; + Method[] clazzPojoMethods = null; + Class<?> currentClazz = clazzPojo; + while (!currentClazz.equals(Object.class)) { + Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); + if (currentClazz == clazzPojo) { + clazzPojoMethods = currentClazzMethods; + } + for (Method method : currentClazzMethods) { + if (method.getAnnotation(OnOpen.class) != null) { + checkPublic(method); + if (open == null) { + open = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(open, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnOpen.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnClose.class) != null) { + checkPublic(method); + if (close == null) { + close = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(close, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnClose.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnError.class) != null) { + checkPublic(method); + if (error == null) { + error = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(error, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnError.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnMessage.class) != null) { + checkPublic(method); + MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders); + boolean found = false; + for (MessageHandlerInfo otherMessageHandler : onMessage) { + if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) { + found = true; + if (currentClazz == clazzPojo || + !isMethodOverride(messageHandler.m, otherMessageHandler.m)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnMessage.class, currentClazz)); + } + } + } + if (!found) { + onMessage.add(messageHandler); + } + } else { + // Method not annotated + } + } + currentClazz = currentClazz.getSuperclass(); + } + // If the methods are not on clazzPojo and they are overridden + // by a non annotated method in clazzPojo, they should be ignored + if (open != null && open.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) { + open = null; + } + } + if (close != null && close.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) { + close = null; + } + } + if (error != null && error.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) { + error = null; + } + } + List<MessageHandlerInfo> overriddenOnMessage = new ArrayList<>(); + for (MessageHandlerInfo messageHandler : onMessage) { + if (messageHandler.m.getDeclaringClass() != clazzPojo + && isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) { + overriddenOnMessage.add(messageHandler); + } + } + for (MessageHandlerInfo messageHandler : overriddenOnMessage) { + onMessage.remove(messageHandler); + } + this.onOpen = open; + this.onClose = close; + this.onError = error; + onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN); + onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE); + onErrorParams = getPathParams(onError, MethodType.ON_ERROR); + } + + + private void checkPublic(Method m) throws DeploymentException { + if (!Modifier.isPublic(m.getModifiers())) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.methodNotPublic", m.getName())); + } + } + + + private boolean isMethodOverride(Method method1, Method method2) { + return method1.getName().equals(method2.getName()) + && method1.getReturnType().equals(method2.getReturnType()) + && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes()); + } + + + private boolean isOverridenWithoutAnnotation(Method[] methods, + Method superclazzMethod, Class<? extends Annotation> annotation) { + for (Method method : methods) { + if (isMethodOverride(method, superclazzMethod) + && (method.getAnnotation(annotation) == null)) { + return true; + } + } + return false; + } + + + public String getWsPath() { + return wsPath; + } + + + public Method getOnOpen() { + return onOpen; + } + + + public Object[] getOnOpenArgs(Map<String,String> pathParameters, + Session session, EndpointConfig config) throws DecodeException { + return buildArgs(onOpenParams, pathParameters, session, config, null, + null); + } + + + public Method getOnClose() { + return onClose; + } + + + public Object[] getOnCloseArgs(Map<String,String> pathParameters, + Session session, CloseReason closeReason) throws DecodeException { + return buildArgs(onCloseParams, pathParameters, session, null, null, + closeReason); + } + + + public Method getOnError() { + return onError; + } + + + public Object[] getOnErrorArgs(Map<String,String> pathParameters, + Session session, Throwable throwable) throws DecodeException { + return buildArgs(onErrorParams, pathParameters, session, null, + throwable, null); + } + + + public boolean hasMessageHandlers() { + return !onMessage.isEmpty(); + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Set<MessageHandler> result = new HashSet<>(); + for (MessageHandlerInfo messageMethod : onMessage) { + result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters, + session, config)); + } + return result; + } + + + private static PojoPathParam[] getPathParams(Method m, + MethodType methodType) throws DeploymentException { + if (m == null) { + return new PojoPathParam[0]; + } + boolean foundThrowable = false; + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + PojoPathParam[] result = new PojoPathParam[types.length]; + for (int i = 0; i < types.length; i++) { + Class<?> type = types[i]; + if (type.equals(Session.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_OPEN && + type.equals(EndpointConfig.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_ERROR + && type.equals(Throwable.class)) { + foundThrowable = true; + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_CLOSE && + type.equals(CloseReason.class)) { + result[i] = new PojoPathParam(type, null); + } else { + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + // Check that the type is valid. "0" coerces to every + // valid type + try { + Util.coerceToType(type, "0"); + } catch (IllegalArgumentException iae) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.invalidPathParamType"), + iae); + } + result[i] = new PojoPathParam(type, + ((PathParam) paramAnnotation).value()); + break; + } + } + // Parameters without annotations are not permitted + if (result[i] == null) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.paramWithoutAnnotation", + type, m.getName(), m.getClass().getName())); + } + } + } + if (methodType == MethodType.ON_ERROR && !foundThrowable) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.onErrorNoThrowable", + m.getName(), m.getDeclaringClass().getName())); + } + return result; + } + + + private static Object[] buildArgs(PojoPathParam[] pathParams, + Map<String,String> pathParameters, Session session, + EndpointConfig config, Throwable throwable, CloseReason closeReason) + throws DecodeException { + Object[] result = new Object[pathParams.length]; + for (int i = 0; i < pathParams.length; i++) { + Class<?> type = pathParams[i].getType(); + if (type.equals(Session.class)) { + result[i] = session; + } else if (type.equals(EndpointConfig.class)) { + result[i] = config; + } else if (type.equals(Throwable.class)) { + result[i] = throwable; + } else if (type.equals(CloseReason.class)) { + result[i] = closeReason; + } else { + String name = pathParams[i].getName(); + String value = pathParameters.get(name); + try { + result[i] = Util.coerceToType(type, value); + } catch (Exception e) { + throw new DecodeException(value, sm.getString( + "pojoMethodMapping.decodePathParamFail", + value, type), e); + } + } + } + return result; + } + + + private static class MessageHandlerInfo { + + private final Method m; + private int indexString = -1; + private int indexByteArray = -1; + private int indexByteBuffer = -1; + private int indexPong = -1; + private int indexBoolean = -1; + private int indexSession = -1; + private int indexInputStream = -1; + private int indexReader = -1; + private int indexPrimitive = -1; + private Class<?> primitiveType = null; + private Map<Integer,PojoPathParam> indexPathParams = new HashMap<>(); + private int indexPayload = -1; + private DecoderMatch decoderMatch = null; + private long maxMessageSize = -1; + + public MessageHandlerInfo(Method m, List<DecoderEntry> decoderEntries) { + this.m = m; + + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + + for (int i = 0; i < types.length; i++) { + boolean paramFound = false; + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + indexPathParams.put( + Integer.valueOf(i), new PojoPathParam(types[i], + ((PathParam) paramAnnotation).value())); + paramFound = true; + break; + } + } + if (paramFound) { + continue; + } + if (String.class.isAssignableFrom(types[i])) { + if (indexString == -1) { + indexString = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Reader.class.isAssignableFrom(types[i])) { + if (indexReader == -1) { + indexReader = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (boolean.class == types[i]) { + if (indexBoolean == -1) { + indexBoolean = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateLastParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (ByteBuffer.class.isAssignableFrom(types[i])) { + if (indexByteBuffer == -1) { + indexByteBuffer = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (byte[].class == types[i]) { + if (indexByteArray == -1) { + indexByteArray = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (InputStream.class.isAssignableFrom(types[i])) { + if (indexInputStream == -1) { + indexInputStream = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Util.isPrimitive(types[i])) { + if (indexPrimitive == -1) { + indexPrimitive = i; + primitiveType = types[i]; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Session.class.isAssignableFrom(types[i])) { + if (indexSession == -1) { + indexSession = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateSessionParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (PongMessage.class.isAssignableFrom(types[i])) { + if (indexPong == -1) { + indexPong = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicatePongMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else { + if (decoderMatch != null && decoderMatch.hasMatches()) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + decoderMatch = new DecoderMatch(types[i], decoderEntries); + + if (decoderMatch.hasMatches()) { + indexPayload = i; + } + } + } + + // Additional checks required + if (indexString != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexString; + } + } + if (indexReader != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexReader; + } + } + if (indexByteArray != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteArray; + } + } + if (indexByteBuffer != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteBuffer; + } + } + if (indexInputStream != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexInputStream; + } + } + if (indexPrimitive != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPrimitive; + } + } + if (indexPong != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.pongWithPayload", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPong; + } + } + if (indexPayload == -1 && indexPrimitive == -1 && + indexBoolean != -1) { + // The boolean we found is a payload, not a last flag + indexPayload = indexBoolean; + indexPrimitive = indexBoolean; + primitiveType = Boolean.TYPE; + indexBoolean = -1; + } + if (indexPayload == -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.noPayload", + m.getName(), m.getDeclaringClass().getName())); + } + if (indexPong != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialPong", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexReader != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialReader", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexInputStream != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialInputStream", + m.getName(), m.getDeclaringClass().getName())); + } + if (decoderMatch != null && decoderMatch.hasMatches() && + indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialObject", + m.getName(), m.getDeclaringClass().getName())); + } + + maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize(); + } + + + public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) { + if (otherHandler == null) { + return false; + } + if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) { + return true; + } + if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) { + return true; + } + if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) { + return true; + } + if (indexPong >= 0 && otherHandler.indexPong >= 0) { + return true; + } + if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0 + && primitiveType == otherHandler.primitiveType) { + return true; + } + if (indexReader >= 0 && otherHandler.indexReader >= 0) { + return true; + } + if (indexString >= 0 && otherHandler.indexString >= 0) { + return true; + } + if (decoderMatch != null && otherHandler.decoderMatch != null + && decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) { + return true; + } + return false; + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Object[] params = new Object[m.getParameterTypes().length]; + + for (Map.Entry<Integer,PojoPathParam> entry : + indexPathParams.entrySet()) { + PojoPathParam pathParam = entry.getValue(); + String valueString = pathParameters.get(pathParam.getName()); + Object value = null; + try { + value = Util.coerceToType(pathParam.getType(), valueString); + } catch (Exception e) { + DecodeException de = new DecodeException(valueString, + sm.getString( + "pojoMethodMapping.decodePathParamFail", + valueString, pathParam.getType()), e); + params = new Object[] { de }; + break; + } + params[entry.getKey().intValue()] = value; + } + + Set<MessageHandler> results = new HashSet<>(2); + if (indexBoolean == -1) { + // Basic + if (indexString != -1 || indexPrimitive != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexPayload, false, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexReader != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexReader, true, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteArray, + true, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexByteBuffer != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteBuffer, + false, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexInputStream != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexInputStream, + true, indexSession, true, maxMessageSize); + results.add(mh); + } else if (decoderMatch != null && decoderMatch.hasMatches()) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeBinary( + pojo, m, session, config, + decoderMatch.getBinaryDecoders(), params, + indexPayload, true, indexSession, true, + maxMessageSize); + results.add(mh); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeText( + pojo, m, session, config, + decoderMatch.getTextDecoders(), params, + indexPayload, true, indexSession, maxMessageSize); + results.add(mh); + } + } else { + MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m, + session, params, indexPong, false, indexSession); + results.add(mh); + } + } else { + // ASync + if (indexString != -1) { + MessageHandler mh = new PojoMessageHandlerPartialText(pojo, + m, session, params, indexString, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteArray, true, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteBuffer, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } + } + return results; + } + } + + + private enum MethodType { + ON_OPEN, + ON_CLOSE, + ON_ERROR + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoPathParam.java b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java new file mode 100644 index 00000000..859b6d68 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Stores the parameter type and name for a parameter that needs to be passed to + * an onXxx method of {@link javax.websocket.Endpoint}. The name is only present + * for parameters annotated with + * {@link javax.websocket.server.PathParam}. For the + * {@link javax.websocket.Session} and {@link java.lang.Throwable} parameters, + * {@link #getName()} will always return <code>null</code>. + */ +public class PojoPathParam { + + private final Class<?> type; + private final String name; + + + public PojoPathParam(Class<?> type, String name) { + this.type = type; + this.name = name; + } + + + public Class<?> getType() { + return type; + } + + + public String getName() { + return name; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/package-info.java b/src/java/nginx/unit/websocket/pojo/package-info.java new file mode 100644 index 00000000..39cf80c8 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package provides the necessary plumbing to convert an annotated POJO + * into a WebSocket {@link javax.websocket.Endpoint}. + */ +package nginx.unit.websocket.pojo; diff --git a/src/java/nginx/unit/websocket/server/Constants.java b/src/java/nginx/unit/websocket/server/Constants.java new file mode 100644 index 00000000..5210c4ba --- /dev/null +++ b/src/java/nginx/unit/websocket/server/Constants.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.binaryBufferSize"; + public static final String TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.textBufferSize"; + public static final String ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.noAddAfterHandshake"; + + public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE = + "javax.websocket.server.ServerContainer"; + + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java new file mode 100644 index 00000000..43ffe2bc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.HandshakeRequest; +import javax.websocket.server.ServerEndpointConfig; + +public class DefaultServerEndpointConfigurator + extends ServerEndpointConfig.Configurator { + + @Override + public <T> T getEndpointInstance(Class<T> clazz) + throws InstantiationException { + try { + return clazz.getConstructor().newInstance(); + } catch (InstantiationException e) { + throw e; + } catch (ReflectiveOperationException e) { + InstantiationException ie = new InstantiationException(); + ie.initCause(e); + throw ie; + } + } + + + @Override + public String getNegotiatedSubprotocol(List<String> supported, + List<String> requested) { + + for (String request : requested) { + if (supported.contains(request)) { + return request; + } + } + return ""; + } + + + @Override + public List<Extension> getNegotiatedExtensions(List<Extension> installed, + List<Extension> requested) { + Set<String> installedNames = new HashSet<>(); + for (Extension e : installed) { + installedNames.add(e.getName()); + } + List<Extension> result = new ArrayList<>(); + for (Extension request : requested) { + if (installedNames.contains(request.getName())) { + result.add(request); + } + } + return result; + } + + + @Override + public boolean checkOrigin(String originHeaderValue) { + return true; + } + + @Override + public void modifyHandshake(ServerEndpointConfig sec, + HandshakeRequest request, HandshakeResponse response) { + // NO-OP + } + +} diff --git a/src/java/nginx/unit/websocket/server/LocalStrings.properties b/src/java/nginx/unit/websocket/server/LocalStrings.properties new file mode 100644 index 00000000..5bc12501 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/LocalStrings.properties @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +serverContainer.addNotAllowed=No further Endpoints may be registered once an attempt has been made to use one of the previously registered endpoints +serverContainer.configuratorFail=Failed to create configurator of type [{0}] for POJO of type [{1}] +serverContainer.duplicatePaths=Multiple Endpoints may not be deployed to the same path [{0}] : existing endpoint was [{1}] and new endpoint is [{2}] +serverContainer.encoderFail=Unable to create encoder of type [{0}] +serverContainer.endpointDeploy=Endpoint class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.missingAnnotation=Cannot deploy POJO class [{0}] as it is not annotated with @ServerEndpoint +serverContainer.missingEndpoint=An Endpoint instance has been request for path [{0}] but no matching Endpoint class was found +serverContainer.pojoDeploy=POJO class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.servletContextMismatch=Attempted to register a POJO annotated for WebSocket at path [{0}] in the ServletContext with context path [{1}] when the WebSocket ServerContainer is allocated to the ServletContext with context path [{2}] +serverContainer.servletContextMissing=No ServletContext was specified + +upgradeUtil.incompatibleRsv=Extensions were specified that have incompatible RSV bit usage + +uriTemplate.duplicateParameter=The parameter [{0}] appears more than once in the path which is not permitted +uriTemplate.emptySegment=The path [{0}] contains one or more empty segments which are is not permitted +uriTemplate.invalidPath=The path [{0}] is not valid. +uriTemplate.invalidSegment=The segment [{0}] is not valid in the provided path [{1}] + +wsFrameServer.bytesRead=Read [{0}] bytes into input buffer ready for processing +wsFrameServer.illegalReadState=Unexpected read state [{0}] +wsFrameServer.onDataAvailable=Method entry + +wsHttpUpgradeHandler.closeOnError=Closing WebSocket connection due to an error +wsHttpUpgradeHandler.destroyFailed=Failed to close WebConnection while destroying the WebSocket HttpUpgradeHandler +wsHttpUpgradeHandler.noPreInit=The preInit() method must be called to configure the WebSocket HttpUpgradeHandler before the container calls init(). Usually, this means the Servlet that created the WsHttpUpgradeHandler instance should also call preInit() +wsHttpUpgradeHandler.serverStop=The server is stopping + +wsRemoteEndpointServer.closeFailed=Failed to close the ServletOutputStream connection cleanly diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java new file mode 100644 index 00000000..162f01c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.ServerEndpointConfig; + +import nginx.unit.Request; + +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.security.ConcurrentMessageDigest; +import nginx.unit.websocket.Constants; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.TransformationFactory; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.WsHandshakeResponse; +import nginx.unit.websocket.pojo.PojoEndpointServer; + +public class UpgradeUtil { + + private static final StringManager sm = + StringManager.getManager(UpgradeUtil.class.getPackage().getName()); + private static final byte[] WS_ACCEPT = + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( + StandardCharsets.ISO_8859_1); + + private UpgradeUtil() { + // Utility class. Hide default constructor. + } + + /** + * Checks to see if this is an HTTP request that includes a valid upgrade + * request to web socket. + * <p> + * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java + * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC + * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. + * @param request The request to check if it is an HTTP upgrade request for + * a WebSocket connection + * @param response The response associated with the request + * @return <code>true</code> if the request includes a HTTP Upgrade request + * for the WebSocket protocol, otherwise <code>false</code> + */ + public static boolean isWebSocketUpgradeRequest(ServletRequest request, + ServletResponse response) { + + Request r = (Request) request.getAttribute(Request.BARE); + + return ((request instanceof HttpServletRequest) && + (response instanceof HttpServletResponse) && + (r != null) && + (r.isUpgrade())); + } + + + public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, + HttpServletResponse resp, ServerEndpointConfig sec, + Map<String,String> pathParams) + throws ServletException, IOException { + + + // Origin check + String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); + + if (!sec.getConfigurator().checkOrigin(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN); + return; + } + // Sub-protocols + List<String> subProtocols = getTokensFromHeader(req, + Constants.WS_PROTOCOL_HEADER_NAME); + String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( + sec.getSubprotocols(), subProtocols); + + // Extensions + // Should normally only be one header but handle the case of multiple + // headers + List<Extension> extensionsRequested = new ArrayList<>(); + Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); + while (extHeaders.hasMoreElements()) { + Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); + } + + // Negotiation phase 1. By default this simply filters out the + // extensions that the server does not support but applications could + // use a custom configurator to do more than this. + List<Extension> installedExtensions = null; + if (sec.getExtensions().size() == 0) { + installedExtensions = Constants.INSTALLED_EXTENSIONS; + } else { + installedExtensions = new ArrayList<>(); + installedExtensions.addAll(sec.getExtensions()); + installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); + } + List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( + installedExtensions, extensionsRequested); + + // Negotiation phase 2. Create the Transformations that will be applied + // to this connection. Note than an extension may be dropped at this + // point if the client has requested a configuration that the server is + // unable to support. + List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1); + + List<Extension> negotiatedExtensionsPhase2; + if (transformations.isEmpty()) { + negotiatedExtensionsPhase2 = Collections.emptyList(); + } else { + negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); + for (Transformation t : transformations) { + negotiatedExtensionsPhase2.add(t.getExtensionResponse()); + } + } + + WsHttpUpgradeHandler wsHandler = + req.upgrade(WsHttpUpgradeHandler.class); + + WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); + WsHandshakeResponse wsResponse = new WsHandshakeResponse(); + WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = + new WsPerSessionServerEndpointConfig(sec); + sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, + wsRequest, wsResponse); + //wsRequest.finished(); + + // Add any additional headers + for (Entry<String,List<String>> entry : + wsResponse.getHeaders().entrySet()) { + for (String headerValue: entry.getValue()) { + resp.addHeader(entry.getKey(), headerValue); + } + } + + Endpoint ep; + try { + Class<?> clazz = sec.getEndpointClass(); + if (Endpoint.class.isAssignableFrom(clazz)) { + ep = (Endpoint) sec.getConfigurator().getEndpointInstance( + clazz); + } else { + ep = new PojoEndpointServer(); + // Need to make path params available to POJO + perSessionServerEndpointConfig.getUserProperties().put( + nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); + } + } catch (InstantiationException e) { + throw new ServletException(e); + } + + wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, + negotiatedExtensionsPhase2, subProtocol, null, pathParams, + req.isSecure()); + + wsHandler.init(null); + } + + + private static List<Transformation> createTransformations( + List<Extension> negotiatedExtensions) { + + TransformationFactory factory = TransformationFactory.getInstance(); + + LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences = + new LinkedHashMap<>(); + + // Result will likely be smaller than this + List<Transformation> result = new ArrayList<>(negotiatedExtensions.size()); + + for (Extension extension : negotiatedExtensions) { + List<List<Extension.Parameter>> preferences = + extensionPreferences.get(extension.getName()); + + if (preferences == null) { + preferences = new ArrayList<>(); + extensionPreferences.put(extension.getName(), preferences); + } + + preferences.add(extension.getParameters()); + } + + for (Map.Entry<String,List<List<Extension.Parameter>>> entry : + extensionPreferences.entrySet()) { + Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); + if (transformation != null) { + result.add(transformation); + } + } + return result; + } + + + private static void append(StringBuilder sb, Extension extension) { + if (extension == null || extension.getName() == null || extension.getName().length() == 0) { + return; + } + + sb.append(extension.getName()); + + for (Extension.Parameter p : extension.getParameters()) { + sb.append(';'); + sb.append(p.getName()); + if (p.getValue() != null) { + sb.append('='); + sb.append(p.getValue()); + } + } + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static boolean headerContainsToken(HttpServletRequest req, + String headerName, String target) { + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + if (target.equalsIgnoreCase(token.trim())) { + return true; + } + } + } + return false; + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static List<String> getTokensFromHeader(HttpServletRequest req, + String headerName) { + List<String> result = new ArrayList<>(); + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + result.add(token.trim()); + } + } + return result; + } + + + private static String getWebSocketAccept(String key) { + byte[] digest = ConcurrentMessageDigest.digestSHA1( + key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); + return Base64.encodeBase64String(digest); + } +} diff --git a/src/java/nginx/unit/websocket/server/UriTemplate.java b/src/java/nginx/unit/websocket/server/UriTemplate.java new file mode 100644 index 00000000..7877fac9 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UriTemplate.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.DeploymentException; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Extracts path parameters from URIs used to create web socket connections + * using the URI template defined for the associated Endpoint. + */ +public class UriTemplate { + + private static final StringManager sm = StringManager.getManager(UriTemplate.class); + + private final String normalized; + private final List<Segment> segments = new ArrayList<>(); + private final boolean hasParameters; + + + public UriTemplate(String path) throws DeploymentException { + + if (path == null || path.length() ==0 || !path.startsWith("/")) { + throw new DeploymentException( + sm.getString("uriTemplate.invalidPath", path)); + } + + StringBuilder normalized = new StringBuilder(path.length()); + Set<String> paramNames = new HashSet<>(); + + // Include empty segments. + String[] segments = path.split("/", -1); + int paramCount = 0; + int segmentCount = 0; + + for (int i = 0; i < segments.length; i++) { + String segment = segments[i]; + if (segment.length() == 0) { + if (i == 0 || (i == segments.length - 1 && paramCount == 0)) { + // Ignore the first empty segment as the path must always + // start with '/' + // Ending with a '/' is also OK for instances used for + // matches but not for parameterised templates. + continue; + } else { + // As per EG discussion, all other empty segments are + // invalid + throw new IllegalArgumentException(sm.getString( + "uriTemplate.emptySegment", path)); + } + } + normalized.append('/'); + int index = -1; + if (segment.startsWith("{") && segment.endsWith("}")) { + index = segmentCount; + segment = segment.substring(1, segment.length() - 1); + normalized.append('{'); + normalized.append(paramCount++); + normalized.append('}'); + if (!paramNames.add(segment)) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.duplicateParameter", segment)); + } + } else { + if (segment.contains("{") || segment.contains("}")) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.invalidSegment", segment, path)); + } + normalized.append(segment); + } + this.segments.add(new Segment(index, segment)); + segmentCount++; + } + + this.normalized = normalized.toString(); + this.hasParameters = paramCount > 0; + } + + + public Map<String,String> match(UriTemplate candidate) { + + Map<String,String> result = new HashMap<>(); + + // Should not happen but for safety + if (candidate.getSegmentCount() != getSegmentCount()) { + return null; + } + + Iterator<Segment> candidateSegments = + candidate.getSegments().iterator(); + Iterator<Segment> targetSegments = segments.iterator(); + + while (candidateSegments.hasNext()) { + Segment candidateSegment = candidateSegments.next(); + Segment targetSegment = targetSegments.next(); + + if (targetSegment.getParameterIndex() == -1) { + // Not a parameter - values must match + if (!targetSegment.getValue().equals( + candidateSegment.getValue())) { + // Not a match. Stop here + return null; + } + } else { + // Parameter + result.put(targetSegment.getValue(), + candidateSegment.getValue()); + } + } + + return result; + } + + + public boolean hasParameters() { + return hasParameters; + } + + + public int getSegmentCount() { + return segments.size(); + } + + + public String getNormalizedPath() { + return normalized; + } + + + private List<Segment> getSegments() { + return segments; + } + + + private static class Segment { + private final int parameterIndex; + private final String value; + + public Segment(int parameterIndex, String value) { + this.parameterIndex = parameterIndex; + this.value = value; + } + + + public int getParameterIndex() { + return parameterIndex; + } + + + public String getValue() { + return value; + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsContextListener.java b/src/java/nginx/unit/websocket/server/WsContextListener.java new file mode 100644 index 00000000..07137856 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsContextListener.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.ServletContext; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +/** + * In normal usage, this {@link ServletContextListener} does not need to be + * explicitly configured as the {@link WsSci} performs all the necessary + * bootstrap and installs this listener in the {@link ServletContext}. If the + * {@link WsSci} is disabled, this listener must be added manually to every + * {@link ServletContext} that uses WebSocket to bootstrap the + * {@link WsServerContainer} correctly. + */ +public class WsContextListener implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + // Don't trigger WebSocket initialization if a WebSocket Server + // Container is already present + if (sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE) == null) { + WsSci.init(sce.getServletContext(), false); + } + } + + @Override + public void contextDestroyed(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + Object obj = sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + if (obj instanceof WsServerContainer) { + ((WsServerContainer) obj).destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsFilter.java b/src/java/nginx/unit/websocket/server/WsFilter.java new file mode 100644 index 00000000..abea71fc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsFilter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.GenericFilter; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Handles the initial HTTP connection for WebSocket connections. + */ +public class WsFilter extends GenericFilter { + + private static final long serialVersionUID = 1L; + + private transient WsServerContainer sc; + + + @Override + public void init() throws ServletException { + sc = (WsServerContainer) getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + } + + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + + // This filter only needs to handle WebSocket upgrade requests + if (!sc.areEndpointsRegistered() || + !UpgradeUtil.isWebSocketUpgradeRequest(request, response)) { + chain.doFilter(request, response); + return; + } + + // HTTP request with an upgrade header for WebSocket present + HttpServletRequest req = (HttpServletRequest) request; + HttpServletResponse resp = (HttpServletResponse) response; + + // Check to see if this WebSocket implementation has a matching mapping + String path; + String pathInfo = req.getPathInfo(); + if (pathInfo == null) { + path = req.getServletPath(); + } else { + path = req.getServletPath() + pathInfo; + } + WsMappingResult mappingResult = sc.findMapping(path); + + if (mappingResult == null) { + // No endpoint registered for the requested path. Let the + // application handle it (it might redirect or forward for example) + chain.doFilter(request, response); + return; + } + + UpgradeUtil.doUpgrade(sc, req, resp, mappingResult.getConfig(), + mappingResult.getPathParams()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java new file mode 100644 index 00000000..fa774302 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.net.URI; +import java.net.URISyntaxException; +import java.security.Principal; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; + +/** + * Represents the request that this session was opened under. + */ +public class WsHandshakeRequest implements HandshakeRequest { + + private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class); + + private final URI requestUri; + private final Map<String,List<String>> parameterMap; + private final String queryString; + private final Principal userPrincipal; + private final Map<String,List<String>> headers; + private final Object httpSession; + + private volatile HttpServletRequest request; + + + public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) { + + this.request = request; + + queryString = request.getQueryString(); + userPrincipal = request.getUserPrincipal(); + httpSession = request.getSession(false); + requestUri = buildRequestUri(request); + + // ParameterMap + Map<String,String[]> originalParameters = request.getParameterMap(); + Map<String,List<String>> newParameters = + new HashMap<>(originalParameters.size()); + for (Entry<String,String[]> entry : originalParameters.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Arrays.asList(entry.getValue()))); + } + for (Entry<String,String> entry : pathParams.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Collections.singletonList(entry.getValue()))); + } + parameterMap = Collections.unmodifiableMap(newParameters); + + // Headers + Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>(); + + Enumeration<String> headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + + newHeaders.put(headerName, Collections.unmodifiableList( + Collections.list(request.getHeaders(headerName)))); + } + + headers = Collections.unmodifiableMap(newHeaders); + } + + @Override + public URI getRequestURI() { + return requestUri; + } + + @Override + public Map<String,List<String>> getParameterMap() { + return parameterMap; + } + + @Override + public String getQueryString() { + return queryString; + } + + @Override + public Principal getUserPrincipal() { + return userPrincipal; + } + + @Override + public Map<String,List<String>> getHeaders() { + return headers; + } + + @Override + public boolean isUserInRole(String role) { + if (request == null) { + throw new IllegalStateException(); + } + + return request.isUserInRole(role); + } + + @Override + public Object getHttpSession() { + return httpSession; + } + + /** + * Called when the HandshakeRequest is no longer required. Since an instance + * of this class retains a reference to the current HttpServletRequest that + * reference needs to be cleared as the HttpServletRequest may be reused. + * + * There is no reason for instances of this class to be accessed once the + * handshake has been completed. + */ + void finished() { + request = null; + } + + + /* + * See RequestUtil.getRequestURL() + */ + private static URI buildRequestUri(HttpServletRequest req) { + + StringBuffer uri = new StringBuffer(); + String scheme = req.getScheme(); + int port = req.getServerPort(); + if (port < 0) { + // Work around java.net.URL bug + port = 80; + } + + if ("http".equals(scheme)) { + uri.append("ws"); + } else if ("https".equals(scheme)) { + uri.append("wss"); + } else { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.unknownScheme", scheme)); + } + + uri.append("://"); + uri.append(req.getServerName()); + + if ((scheme.equals("http") && (port != 80)) + || (scheme.equals("https") && (port != 443))) { + uri.append(':'); + uri.append(port); + } + + uri.append(req.getRequestURI()); + + if (req.getQueryString() != null) { + uri.append("?"); + uri.append(req.getQueryString()); + } + + try { + return new URI(uri.toString()); + } catch (URISyntaxException e) { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e); + } + } + + public Object getAttribute(String name) + { + return request != null ? request.getAttribute(name) : null; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java new file mode 100644 index 00000000..cc39ab73 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.WebConnection; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsIOException; +import nginx.unit.websocket.WsSession; + +import nginx.unit.Request; + +/** + * Servlet 3.1 HTTP upgrade handler for WebSocket connections. + */ +public class WsHttpUpgradeHandler implements HttpUpgradeHandler { + + private final Log log = LogFactory.getLog(WsHttpUpgradeHandler.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsHttpUpgradeHandler.class); + + private final ClassLoader applicationClassLoader; + + private Endpoint ep; + private EndpointConfig endpointConfig; + private WsServerContainer webSocketContainer; + private WsHandshakeRequest handshakeRequest; + private List<Extension> negotiatedExtensions; + private String subProtocol; + private Transformation transformation; + private Map<String,String> pathParameters; + private boolean secure; + private WebConnection connection; + private WsRemoteEndpointImplServer wsRemoteEndpointServer; + private WsSession wsSession; + + + public WsHttpUpgradeHandler() { + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + } + + public void preInit(Endpoint ep, EndpointConfig endpointConfig, + WsServerContainer wsc, WsHandshakeRequest handshakeRequest, + List<Extension> negotiatedExtensionsPhase2, String subProtocol, + Transformation transformation, Map<String,String> pathParameters, + boolean secure) { + this.ep = ep; + this.endpointConfig = endpointConfig; + this.webSocketContainer = wsc; + this.handshakeRequest = handshakeRequest; + this.negotiatedExtensions = negotiatedExtensionsPhase2; + this.subProtocol = subProtocol; + this.transformation = transformation; + this.pathParameters = pathParameters; + this.secure = secure; + } + + + @Override + public void init(WebConnection connection) { + if (ep == null) { + throw new IllegalStateException( + sm.getString("wsHttpUpgradeHandler.noPreInit")); + } + + String httpSessionId = null; + Object session = handshakeRequest.getHttpSession(); + if (session != null ) { + httpSessionId = ((HttpSession) session).getId(); + } + + nginx.unit.Context.trace("UpgradeHandler.init(" + connection + ")"); + +/* + // Need to call onOpen using the web application's class loader + // Create the frame using the application's class loader so it can pick + // up application specific config from the ServerContainerImpl + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); +*/ + try { + Request r = (Request) handshakeRequest.getAttribute(Request.BARE); + + wsRemoteEndpointServer = new WsRemoteEndpointImplServer(webSocketContainer); + wsSession = new WsSession(ep, wsRemoteEndpointServer, + webSocketContainer, handshakeRequest.getRequestURI(), + handshakeRequest.getParameterMap(), + handshakeRequest.getQueryString(), + handshakeRequest.getUserPrincipal(), httpSessionId, + negotiatedExtensions, subProtocol, pathParameters, secure, + endpointConfig, r); + + ep.onOpen(wsSession, endpointConfig); + webSocketContainer.registerSession(ep, wsSession); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); +/* + } finally { + t.setContextClassLoader(cl); +*/ + } + } + + + + @Override + public void destroy() { + if (connection != null) { + try { + connection.close(); + } catch (Exception e) { + log.error(sm.getString("wsHttpUpgradeHandler.destroyFailed"), e); + } + } + } + + + private void onError(Throwable throwable) { + // Need to call onError using the web application's class loader + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + ep.onError(wsSession, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void close(CloseReason cr) { + /* + * Any call to this method is a result of a problem reading from the + * client. At this point that state of the connection is unknown. + * Attempt to send a close frame to the client and then close the socket + * immediately. There is no point in waiting for a close frame from the + * client because there is no guarantee that we can recover from + * whatever messed up state the client put the connection into. + */ + wsSession.onClose(cr); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsMappingResult.java b/src/java/nginx/unit/websocket/server/WsMappingResult.java new file mode 100644 index 00000000..a7a4c022 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsMappingResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Map; + +import javax.websocket.server.ServerEndpointConfig; + +class WsMappingResult { + + private final ServerEndpointConfig config; + private final Map<String,String> pathParams; + + + WsMappingResult(ServerEndpointConfig config, + Map<String,String> pathParams) { + this.config = config; + this.pathParams = pathParams; + } + + + ServerEndpointConfig getConfig() { + return config; + } + + + Map<String,String> getPathParams() { + return pathParams; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java new file mode 100644 index 00000000..2be050cb --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Extension; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Wraps the provided {@link ServerEndpointConfig} and provides a per session + * view - the difference being that the map returned by {@link + * #getUserProperties()} is unique to this instance rather than shared with the + * wrapped {@link ServerEndpointConfig}. + */ +class WsPerSessionServerEndpointConfig implements ServerEndpointConfig { + + private final ServerEndpointConfig perEndpointConfig; + private final Map<String,Object> perSessionUserProperties = + new ConcurrentHashMap<>(); + + WsPerSessionServerEndpointConfig(ServerEndpointConfig perEndpointConfig) { + this.perEndpointConfig = perEndpointConfig; + perSessionUserProperties.putAll(perEndpointConfig.getUserProperties()); + } + + @Override + public List<Class<? extends Encoder>> getEncoders() { + return perEndpointConfig.getEncoders(); + } + + @Override + public List<Class<? extends Decoder>> getDecoders() { + return perEndpointConfig.getDecoders(); + } + + @Override + public Map<String,Object> getUserProperties() { + return perSessionUserProperties; + } + + @Override + public Class<?> getEndpointClass() { + return perEndpointConfig.getEndpointClass(); + } + + @Override + public String getPath() { + return perEndpointConfig.getPath(); + } + + @Override + public List<String> getSubprotocols() { + return perEndpointConfig.getSubprotocols(); + } + + @Override + public List<Extension> getExtensions() { + return perEndpointConfig.getExtensions(); + } + + @Override + public Configurator getConfigurator() { + return perEndpointConfig.getConfigurator(); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java new file mode 100644 index 00000000..6d10a3be --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.EOFException; +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.nio.channels.InterruptedByTimeoutException; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsRemoteEndpointImplBase; + +/** + * This is the server side {@link javax.websocket.RemoteEndpoint} implementation + * - i.e. what the server uses to send data to the client. + */ +public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplServer.class); + private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static + + private volatile SendHandler handler = null; + private volatile ByteBuffer[] buffers = null; + + private volatile long timeoutExpiry = -1; + private volatile boolean close; + + public WsRemoteEndpointImplServer( + WsServerContainer serverContainer) { + } + + + @Override + protected final boolean isMasked() { + return false; + } + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... buffers) { + } + + @Override + protected void doClose() { + if (handler != null) { + // close() can be triggered by a wide range of scenarios. It is far + // simpler just to always use a dispatch than it is to try and track + // whether or not this method was called by the same thread that + // triggered the write + clearHandler(new EOFException(), true); + } + } + + + protected long getTimeoutExpiry() { + return timeoutExpiry; + } + + + /* + * Currently this is only called from the background thread so we could just + * call clearHandler() with useDispatch == false but the method parameter + * was added in case other callers started to use this method to make sure + * that those callers think through what the correct value of useDispatch is + * for them. + */ + protected void onTimeout(boolean useDispatch) { + if (handler != null) { + clearHandler(new SocketTimeoutException(), useDispatch); + } + close(); + } + + + @Override + protected void setTransformation(Transformation transformation) { + // Overridden purely so it is visible to other classes in this package + super.setTransformation(transformation); + } + + + /** + * + * @param t The throwable associated with any error that + * occurred + * @param useDispatch Should {@link SendHandler#onResult(SendResult)} be + * called from a new thread, keeping in mind the + * requirements of + * {@link javax.websocket.RemoteEndpoint.Async} + */ + private void clearHandler(Throwable t, boolean useDispatch) { + // Setting the result marks this (partial) message as + // complete which means the next one may be sent which + // could update the value of the handler. Therefore, keep a + // local copy before signalling the end of the (partial) + // message. + SendHandler sh = handler; + handler = null; + buffers = null; + if (sh != null) { + if (useDispatch) { + OnResultRunnable r = new OnResultRunnable(sh, t); + } else { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } + } + + + private static class OnResultRunnable implements Runnable { + + private final SendHandler sh; + private final Throwable t; + + private OnResultRunnable(SendHandler sh, Throwable t) { + this.sh = sh; + this.t = t; + } + + @Override + public void run() { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSci.java b/src/java/nginx/unit/websocket/server/WsSci.java new file mode 100644 index 00000000..cdecce27 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSci.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.lang.reflect.Modifier; +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.ServletContainerInitializer; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.annotation.HandlesTypes; +import javax.websocket.ContainerProvider; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerApplicationConfig; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Registers an interest in any class that is annotated with + * {@link ServerEndpoint} so that Endpoint can be published via the WebSocket + * server. + */ +@HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class, + Endpoint.class}) +public class WsSci implements ServletContainerInitializer { + + @Override + public void onStartup(Set<Class<?>> clazzes, ServletContext ctx) + throws ServletException { + + WsServerContainer sc = init(ctx, true); + + if (clazzes == null || clazzes.size() == 0) { + return; + } + + // Group the discovered classes by type + Set<ServerApplicationConfig> serverApplicationConfigs = new HashSet<>(); + Set<Class<? extends Endpoint>> scannedEndpointClazzes = new HashSet<>(); + Set<Class<?>> scannedPojoEndpoints = new HashSet<>(); + + try { + // wsPackage is "javax.websocket." + String wsPackage = ContainerProvider.class.getName(); + wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1); + for (Class<?> clazz : clazzes) { + int modifiers = clazz.getModifiers(); + if (!Modifier.isPublic(modifiers) || + Modifier.isAbstract(modifiers)) { + // Non-public or abstract - skip it. + continue; + } + // Protect against scanning the WebSocket API JARs + if (clazz.getName().startsWith(wsPackage)) { + continue; + } + if (ServerApplicationConfig.class.isAssignableFrom(clazz)) { + serverApplicationConfigs.add( + (ServerApplicationConfig) clazz.getConstructor().newInstance()); + } + if (Endpoint.class.isAssignableFrom(clazz)) { + @SuppressWarnings("unchecked") + Class<? extends Endpoint> endpoint = + (Class<? extends Endpoint>) clazz; + scannedEndpointClazzes.add(endpoint); + } + if (clazz.isAnnotationPresent(ServerEndpoint.class)) { + scannedPojoEndpoints.add(clazz); + } + } + } catch (ReflectiveOperationException e) { + throw new ServletException(e); + } + + // Filter the results + Set<ServerEndpointConfig> filteredEndpointConfigs = new HashSet<>(); + Set<Class<?>> filteredPojoEndpoints = new HashSet<>(); + + if (serverApplicationConfigs.isEmpty()) { + filteredPojoEndpoints.addAll(scannedPojoEndpoints); + } else { + for (ServerApplicationConfig config : serverApplicationConfigs) { + Set<ServerEndpointConfig> configFilteredEndpoints = + config.getEndpointConfigs(scannedEndpointClazzes); + if (configFilteredEndpoints != null) { + filteredEndpointConfigs.addAll(configFilteredEndpoints); + } + Set<Class<?>> configFilteredPojos = + config.getAnnotatedEndpointClasses( + scannedPojoEndpoints); + if (configFilteredPojos != null) { + filteredPojoEndpoints.addAll(configFilteredPojos); + } + } + } + + try { + // Deploy endpoints + for (ServerEndpointConfig config : filteredEndpointConfigs) { + sc.addEndpoint(config); + } + // Deploy POJOs + for (Class<?> clazz : filteredPojoEndpoints) { + sc.addEndpoint(clazz); + } + } catch (DeploymentException e) { + throw new ServletException(e); + } + } + + + static WsServerContainer init(ServletContext servletContext, + boolean initBySciMechanism) { + + WsServerContainer sc = new WsServerContainer(servletContext); + + servletContext.setAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc); + + servletContext.addListener(new WsSessionListener(sc)); + // Can't register the ContextListener again if the ContextListener is + // calling this method + if (initBySciMechanism) { + servletContext.addListener(new WsContextListener()); + } + + return sc; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsServerContainer.java b/src/java/nginx/unit/websocket/server/WsServerContainer.java new file mode 100644 index 00000000..069fc54f --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsServerContainer.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; +import javax.websocket.server.ServerEndpointConfig.Configurator; + +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.WsSession; +import nginx.unit.websocket.WsWebSocketContainer; +import nginx.unit.websocket.pojo.PojoMethodMapping; + +/** + * Provides a per class loader (i.e. per web application) instance of a + * ServerContainer. Web application wide defaults may be configured by setting + * the following servlet context initialisation parameters to the desired + * values. + * <ul> + * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> + * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> + * </ul> + */ +public class WsServerContainer extends WsWebSocketContainer + implements ServerContainer { + + private static final StringManager sm = StringManager.getManager(WsServerContainer.class); + + private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = + new CloseReason(CloseCodes.VIOLATED_POLICY, + "This connection was established under an authenticated " + + "HTTP session that has ended."); + + private final ServletContext servletContext; + private final Map<String,ServerEndpointConfig> configExactMatchMap = + new ConcurrentHashMap<>(); + private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap = + new ConcurrentHashMap<>(); + private volatile boolean enforceNoAddAfterHandshake = + nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE; + private volatile boolean addAllowed = true; + private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>(); + private volatile boolean endpointsRegistered = false; + + WsServerContainer(ServletContext servletContext) { + + this.servletContext = servletContext; + setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName())); + + // Configure servlet context wide defaults + String value = servletContext.getInitParameter( + Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxTextMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM); + if (value != null) { + setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value)); + } + + FilterRegistration.Dynamic fr = servletContext.addFilter( + "Tomcat WebSocket (JSR356) Filter", new WsFilter()); + fr.setAsyncSupported(true); + + EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST, + DispatcherType.FORWARD); + + fr.addMappingForUrlPatterns(types, true, "/*"); + } + + + /** + * Published the provided endpoint implementation at the specified path with + * the specified configuration. {@link #WsServerContainer(ServletContext)} + * must be called before calling this method. + * + * @param sec The configuration to use when creating endpoint instances + * @throws DeploymentException if the endpoint cannot be published as + * requested + */ + @Override + public void addEndpoint(ServerEndpointConfig sec) + throws DeploymentException { + + if (enforceNoAddAfterHandshake && !addAllowed) { + throw new DeploymentException( + sm.getString("serverContainer.addNotAllowed")); + } + + if (servletContext == null) { + throw new DeploymentException( + sm.getString("serverContainer.servletContextMissing")); + } + String path = sec.getPath(); + + // Add method mapping to user properties + PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(), + sec.getDecoders(), path); + if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null + || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) { + sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY, + methodMapping); + } + + UriTemplate uriTemplate = new UriTemplate(path); + if (uriTemplate.hasParameters()) { + Integer key = Integer.valueOf(uriTemplate.getSegmentCount()); + SortedSet<TemplatePathMatch> templateMatches = + configTemplateMatchMap.get(key); + if (templateMatches == null) { + // Ensure that if concurrent threads execute this block they + // both end up using the same TreeSet instance + templateMatches = new TreeSet<>( + TemplatePathMatchComparator.getInstance()); + configTemplateMatchMap.putIfAbsent(key, templateMatches); + templateMatches = configTemplateMatchMap.get(key); + } + if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) { + // Duplicate uriTemplate; + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + sec.getEndpointClass(), + sec.getEndpointClass())); + } + } else { + // Exact match + ServerEndpointConfig old = configExactMatchMap.put(path, sec); + if (old != null) { + // Duplicate path mappings + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + old.getEndpointClass(), + sec.getEndpointClass())); + } + } + + endpointsRegistered = true; + } + + + /** + * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)} + * for publishing plain old java objects (POJOs) that have been annotated as + * WebSocket endpoints. + * + * @param pojo The annotated POJO + */ + @Override + public void addEndpoint(Class<?> pojo) throws DeploymentException { + + ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("serverContainer.missingAnnotation", + pojo.getName())); + } + String path = annotation.value(); + + // Validate encoders + validateEncoders(annotation.encoders()); + + // ServerEndpointConfig + ServerEndpointConfig sec; + Class<? extends Configurator> configuratorClazz = + annotation.configurator(); + Configurator configurator = null; + if (!configuratorClazz.equals(Configurator.class)) { + try { + configurator = annotation.configurator().getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.configuratorFail", + annotation.configurator().getName(), + pojo.getClass().getName()), e); + } + } + if (configurator == null) { + configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator(); + } + sec = ServerEndpointConfig.Builder.create(pojo, path). + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + subprotocols(Arrays.asList(annotation.subprotocols())). + configurator(configurator). + build(); + + addEndpoint(sec); + } + + + boolean areEndpointsRegistered() { + return endpointsRegistered; + } + + + /** + * Until the WebSocket specification provides such a mechanism, this Tomcat + * proprietary method is provided to enable applications to programmatically + * determine whether or not to upgrade an individual request to WebSocket. + * <p> + * Note: This method is not used by Tomcat but is used directly by + * third-party code and must not be removed. + * + * @param request The request object to be upgraded + * @param response The response object to be populated with the result of + * the upgrade + * @param sec The server endpoint to use to process the upgrade request + * @param pathParams The path parameters associated with the upgrade request + * + * @throws ServletException If a configuration error prevents the upgrade + * from taking place + * @throws IOException If an I/O error occurs during the upgrade process + */ + public void doUpgrade(HttpServletRequest request, + HttpServletResponse response, ServerEndpointConfig sec, + Map<String,String> pathParams) + throws ServletException, IOException { + UpgradeUtil.doUpgrade(this, request, response, sec, pathParams); + } + + + public WsMappingResult findMapping(String path) { + + // Prevent registering additional endpoints once the first attempt has + // been made to use one + if (addAllowed) { + addAllowed = false; + } + + // Check an exact match. Simple case as there are no templates. + ServerEndpointConfig sec = configExactMatchMap.get(path); + if (sec != null) { + return new WsMappingResult(sec, Collections.<String, String>emptyMap()); + } + + // No exact match. Need to look for template matches. + UriTemplate pathUriTemplate = null; + try { + pathUriTemplate = new UriTemplate(path); + } catch (DeploymentException e) { + // Path is not valid so can't be matched to a WebSocketEndpoint + return null; + } + + // Number of segments has to match + Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount()); + SortedSet<TemplatePathMatch> templateMatches = + configTemplateMatchMap.get(key); + + if (templateMatches == null) { + // No templates with an equal number of segments so there will be + // no matches + return null; + } + + // List is in alphabetical order of normalised templates. + // Correct match is the first one that matches. + Map<String,String> pathParams = null; + for (TemplatePathMatch templateMatch : templateMatches) { + pathParams = templateMatch.getUriTemplate().match(pathUriTemplate); + if (pathParams != null) { + sec = templateMatch.getConfig(); + break; + } + } + + if (sec == null) { + // No match + return null; + } + + return new WsMappingResult(sec, pathParams); + } + + + + public boolean isEnforceNoAddAfterHandshake() { + return enforceNoAddAfterHandshake; + } + + + public void setEnforceNoAddAfterHandshake( + boolean enforceNoAddAfterHandshake) { + this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake; + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + super.registerSession(endpoint, wsSession); + if (wsSession.isOpen() && + wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + registerAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + if (wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + unregisterAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + super.unregisterSession(endpoint, wsSession); + } + + + private void registerAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); + if (wsSessions == null) { + wsSessions = Collections.newSetFromMap( + new ConcurrentHashMap<WsSession,Boolean>()); + authenticatedSessions.putIfAbsent(httpSessionId, wsSessions); + wsSessions = authenticatedSessions.get(httpSessionId); + } + wsSessions.add(wsSession); + } + + + private void unregisterAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); + // wsSessions will be null if the HTTP session has ended + if (wsSessions != null) { + wsSessions.remove(wsSession); + } + } + + + public void closeAuthenticatedSession(String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId); + + if (wsSessions != null && !wsSessions.isEmpty()) { + for (WsSession wsSession : wsSessions) { + try { + wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED); + } catch (IOException e) { + // Any IOExceptions during close will have been caught and the + // onError method called. + } + } + } + } + + + private static void validateEncoders(Class<? extends Encoder>[] encoders) + throws DeploymentException { + + for (Class<? extends Encoder> encoder : encoders) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Encoder instance; + try { + encoder.getConstructor().newInstance(); + } catch(ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.encoderFail", encoder.getName()), e); + } + } + } + + + private static class TemplatePathMatch { + private final ServerEndpointConfig config; + private final UriTemplate uriTemplate; + + public TemplatePathMatch(ServerEndpointConfig config, + UriTemplate uriTemplate) { + this.config = config; + this.uriTemplate = uriTemplate; + } + + + public ServerEndpointConfig getConfig() { + return config; + } + + + public UriTemplate getUriTemplate() { + return uriTemplate; + } + } + + + /** + * This Comparator implementation is thread-safe so only create a single + * instance. + */ + private static class TemplatePathMatchComparator + implements Comparator<TemplatePathMatch> { + + private static final TemplatePathMatchComparator INSTANCE = + new TemplatePathMatchComparator(); + + public static TemplatePathMatchComparator getInstance() { + return INSTANCE; + } + + private TemplatePathMatchComparator() { + // Hide default constructor + } + + @Override + public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) { + return tpm1.getUriTemplate().getNormalizedPath().compareTo( + tpm2.getUriTemplate().getNormalizedPath()); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSessionListener.java b/src/java/nginx/unit/websocket/server/WsSessionListener.java new file mode 100644 index 00000000..fc2bc9c5 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSessionListener.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.http.HttpSessionEvent; +import javax.servlet.http.HttpSessionListener; + +public class WsSessionListener implements HttpSessionListener{ + + private final WsServerContainer wsServerContainer; + + + public WsSessionListener(WsServerContainer wsServerContainer) { + this.wsServerContainer = wsServerContainer; + } + + + @Override + public void sessionDestroyed(HttpSessionEvent se) { + wsServerContainer.closeAuthenticatedSession(se.getSession().getId()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsWriteTimeout.java b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java new file mode 100644 index 00000000..2dfc4ab2 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Comparator; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicInteger; + +import nginx.unit.websocket.BackgroundProcess; +import nginx.unit.websocket.BackgroundProcessManager; + +/** + * Provides timeouts for asynchronous web socket writes. On the server side we + * only have access to {@link javax.servlet.ServletOutputStream} and + * {@link javax.servlet.ServletInputStream} so there is no way to set a timeout + * for writes to the client. + */ +public class WsWriteTimeout implements BackgroundProcess { + + private final Set<WsRemoteEndpointImplServer> endpoints = + new ConcurrentSkipListSet<>(new EndpointComparator()); + private final AtomicInteger count = new AtomicInteger(0); + private int backgroundProcessCount = 0; + private volatile int processPeriod = 1; + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + long now = System.currentTimeMillis(); + for (WsRemoteEndpointImplServer endpoint : endpoints) { + if (endpoint.getTimeoutExpiry() < now) { + // Background thread, not the thread that triggered the + // write so no need to use a dispatch + endpoint.onTimeout(false); + } else { + // Endpoints are ordered by timeout expiry so if this point + // is reached there is no need to check the remaining + // endpoints + break; + } + } + } + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 1 which means asynchronous write timeouts are + * processed every 1 second. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + public void register(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.add(endpoint); + if (result) { + int newCount = count.incrementAndGet(); + if (newCount == 1) { + BackgroundProcessManager.getInstance().register(this); + } + } + } + + + public void unregister(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.remove(endpoint); + if (result) { + int newCount = count.decrementAndGet(); + if (newCount == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + } + + + /** + * Note: this comparator imposes orderings that are inconsistent with equals + */ + private static class EndpointComparator implements + Comparator<WsRemoteEndpointImplServer> { + + @Override + public int compare(WsRemoteEndpointImplServer o1, + WsRemoteEndpointImplServer o2) { + + long t1 = o1.getTimeoutExpiry(); + long t2 = o2.getTimeoutExpiry(); + + if (t1 < t2) { + return -1; + } else if (t1 == t2) { + return 0; + } else { + return 1; + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/package-info.java b/src/java/nginx/unit/websocket/server/package-info.java new file mode 100644 index 00000000..87bc85a3 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Server-side specific implementation classes. These are in a separate package + * to make packaging a pure client JAR simpler. + */ +package nginx.unit.websocket.server; |