diff options
author | Max Romanov <max.romanov@nginx.com> | 2019-09-05 15:27:32 +0300 |
---|---|---|
committer | Max Romanov <max.romanov@nginx.com> | 2019-09-05 15:27:32 +0300 |
commit | 2b8cab1e2478547398ad9c2fe68e025c180cac54 (patch) | |
tree | d317fcf9ee52f0f8967116f531784ae533b0ae5a /src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java | |
parent | 3e23afb0d205e503f6cc7d852e34d07da9a5b7f7 (diff) | |
download | unit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.gz unit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.bz2 |
Java: introducing websocket support.
Diffstat (limited to 'src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java')
-rw-r--r-- | src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java | 578 |
1 files changed, 578 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java new file mode 100644 index 00000000..21654487 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +/** + * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot + * more testing before it can be considered robust. + */ +public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { + + private final Log log = + LogFactory.getLog(AsyncChannelWrapperSecure.class); + private static final StringManager sm = + StringManager.getManager(AsyncChannelWrapperSecure.class); + + private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921); + private final AsynchronousSocketChannel socketChannel; + private final SSLEngine sslEngine; + private final ByteBuffer socketReadBuffer; + private final ByteBuffer socketWriteBuffer; + // One thread for read, one for write + private final ExecutorService executor = + Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); + private AtomicBoolean writing = new AtomicBoolean(false); + private AtomicBoolean reading = new AtomicBoolean(false); + + public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, + SSLEngine sslEngine) { + this.socketChannel = socketChannel; + this.sslEngine = sslEngine; + + int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); + socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); + socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize); + } + + @Override + public Future<Integer> read(ByteBuffer dst) { + WrapperFuture<Integer,Void> future = new WrapperFuture<>(); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + + return future; + } + + @Override + public <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler) { + + WrapperFuture<Integer,B> future = + new WrapperFuture<>(handler, attachment); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + } + + @Override + public Future<Integer> write(ByteBuffer src) { + + WrapperFuture<Long,Void> inner = new WrapperFuture<>(); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = + new WriteTask(new ByteBuffer[] {src}, 0, 1, inner); + + executor.execute(writeTask); + + Future<Integer> future = new LongToIntegerFuture(inner); + return future; + } + + @Override + public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler) { + + WrapperFuture<Long,B> future = + new WrapperFuture<>(handler, attachment); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = new WriteTask(srcs, offset, length, future); + + executor.execute(writeTask); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); + } + executor.shutdownNow(); + } + + @Override + public Future<Void> handshake() throws SSLException { + + WrapperFuture<Void,Void> wFuture = new WrapperFuture<>(); + + Thread t = new WebSocketSslHandshakeThread(wFuture); + t.start(); + + return wFuture; + } + + + private class WriteTask implements Runnable { + + private final ByteBuffer[] srcs; + private final int offset; + private final int length; + private final WrapperFuture<Long,?> future; + + public WriteTask(ByteBuffer[] srcs, int offset, int length, + WrapperFuture<Long,?> future) { + this.srcs = srcs; + this.future = future; + this.offset = offset; + this.length = length; + } + + @Override + public void run() { + long written = 0; + + try { + for (int i = offset; i < offset + length; i++) { + ByteBuffer src = srcs[i]; + while (src.hasRemaining()) { + socketWriteBuffer.clear(); + + // Encrypt the data + SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer); + written += r.bytesConsumed(); + Status s = r.getStatus(); + + if (s == Status.OK || s == Status.BUFFER_OVERFLOW) { + // Need to write out the bytes and may need to read from + // the source again to empty it + } else { + // Status.BUFFER_UNDERFLOW - only happens on unwrap + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusWrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + + socketWriteBuffer.flip(); + + // Do the write + int toWrite = r.bytesProduced(); + while (toWrite > 0) { + Future<Integer> f = + socketChannel.write(socketWriteBuffer); + Integer socketWrite = f.get(); + toWrite -= socketWrite.intValue(); + } + } + } + + + if (writing.compareAndSet(true, false)) { + future.complete(Long.valueOf(written)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateWrite"))); + } + } catch (Exception e) { + writing.set(false); + future.fail(e); + } + } + } + + + private class ReadTask implements Runnable { + + private final ByteBuffer dest; + private final WrapperFuture<Integer,?> future; + + public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) { + this.dest = dest; + this.future = future; + } + + @Override + public void run() { + int read = 0; + + boolean forceRead = false; + + try { + while (read == 0) { + socketReadBuffer.compact(); + + if (forceRead) { + forceRead = false; + Future<Integer> f = socketChannel.read(socketReadBuffer); + Integer socketRead = f.get(); + if (socketRead.intValue() == -1) { + throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof")); + } + } + + socketReadBuffer.flip(); + + if (socketReadBuffer.hasRemaining()) { + // Decrypt the data in the buffer + SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest); + read += r.bytesProduced(); + Status s = r.getStatus(); + + if (s == Status.OK) { + // Bytes available for reading and there may be + // sufficient data in the socketReadBuffer to + // support further reads without reading from the + // socket + } else if (s == Status.BUFFER_UNDERFLOW) { + // There is partial data in the socketReadBuffer + if (read == 0) { + // Need more data before the partial data can be + // processed and some output generated + forceRead = true; + } + // else return the data we have and deal with the + // partial data on the next read + } else if (s == Status.BUFFER_OVERFLOW) { + // Not enough space in the destination buffer to + // store all of the data. We could use a bytes read + // value of -bufferSizeRequired to signal the new + // buffer size required but an explicit exception is + // clearer. + if (reading.compareAndSet(true, false)) { + throw new ReadBufferOverflowException(sslEngine. + getSession().getApplicationBufferSize()); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } else { + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusUnwrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + } else { + forceRead = true; + } + } + + + if (reading.compareAndSet(true, false)) { + future.complete(Integer.valueOf(read)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | + ExecutionException | InterruptedException e) { + reading.set(false); + future.fail(e); + } + } + } + + + private class WebSocketSslHandshakeThread extends Thread { + + private final WrapperFuture<Void,Void> hFuture; + + private HandshakeStatus handshakeStatus; + private Status resultStatus; + + public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) { + this.hFuture = hFuture; + } + + @Override + public void run() { + try { + sslEngine.beginHandshake(); + // So the first compact does the right thing + socketReadBuffer.position(socketReadBuffer.limit()); + + handshakeStatus = sslEngine.getHandshakeStatus(); + resultStatus = Status.OK; + + boolean handshaking = true; + + while(handshaking) { + switch (handshakeStatus) { + case NEED_WRAP: { + socketWriteBuffer.clear(); + SSLEngineResult r = + sslEngine.wrap(DUMMY, socketWriteBuffer); + checkResult(r, true); + socketWriteBuffer.flip(); + Future<Integer> fWrite = + socketChannel.write(socketWriteBuffer); + fWrite.get(); + break; + } + case NEED_UNWRAP: { + socketReadBuffer.compact(); + if (socketReadBuffer.position() == 0 || + resultStatus == Status.BUFFER_UNDERFLOW) { + Future<Integer> fRead = + socketChannel.read(socketReadBuffer); + fRead.get(); + } + socketReadBuffer.flip(); + SSLEngineResult r = + sslEngine.unwrap(socketReadBuffer, DUMMY); + checkResult(r, false); + break; + } + case NEED_TASK: { + Runnable r = null; + while ((r = sslEngine.getDelegatedTask()) != null) { + r.run(); + } + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + } + case FINISHED: { + handshaking = false; + break; + } + case NOT_HANDSHAKING: { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.notHandshaking")); + } + } + } + } catch (Exception e) { + hFuture.fail(e); + return; + } + + hFuture.complete(null); + } + + private void checkResult(SSLEngineResult result, boolean wrap) + throws SSLException { + + handshakeStatus = result.getHandshakeStatus(); + resultStatus = result.getStatus(); + + if (resultStatus != Status.OK && + (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus)); + } + if (wrap && result.bytesConsumed() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap")); + } + if (!wrap && result.bytesProduced() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap")); + } + } + } + + + private static class WrapperFuture<T,A> implements Future<T> { + + private final CompletionHandler<T,A> handler; + private final A attachment; + + private volatile T result = null; + private volatile Throwable throwable = null; + private CountDownLatch completionLatch = new CountDownLatch(1); + + public WrapperFuture() { + this(null, null); + } + + public WrapperFuture(CompletionHandler<T,A> handler, A attachment) { + this.handler = handler; + this.attachment = attachment; + } + + public void complete(T result) { + this.result = result; + completionLatch.countDown(); + if (handler != null) { + handler.completed(result, attachment); + } + } + + public void fail(Throwable t) { + throwable = t; + completionLatch.countDown(); + if (handler != null) { + handler.failed(throwable, attachment); + } + } + + @Override + public final boolean cancel(boolean mayInterruptIfRunning) { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isCancelled() { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isDone() { + return completionLatch.getCount() > 0; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + completionLatch.await(); + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean latchResult = completionLatch.await(timeout, unit); + if (latchResult == false) { + throw new TimeoutException(); + } + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + } + + private static final class LongToIntegerFuture implements Future<Integer> { + + private final Future<Long> wrapped; + + public LongToIntegerFuture(Future<Long> wrapped) { + this.wrapped = wrapped; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return wrapped.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return wrapped.isCancelled(); + } + + @Override + public boolean isDone() { + return wrapped.isDone(); + } + + @Override + public Integer get() throws InterruptedException, ExecutionException { + Long result = wrapped.get(); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + + @Override + public Integer get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + Long result = wrapped.get(timeout, unit); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + } + + + private static class SecureIOThreadFactory implements ThreadFactory { + + private AtomicInteger count = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet()); + // No need to set the context class loader. The threads will be + // cleaned up when the connection is closed. + t.setDaemon(true); + return t; + } + } +} |