summaryrefslogtreecommitdiffhomepage
path: root/src/java/nginx/unit/websocket/PerMessageDeflate.java
diff options
context:
space:
mode:
authorMax Romanov <max.romanov@nginx.com>2019-09-05 15:27:32 +0300
committerMax Romanov <max.romanov@nginx.com>2019-09-05 15:27:32 +0300
commit2b8cab1e2478547398ad9c2fe68e025c180cac54 (patch)
treed317fcf9ee52f0f8967116f531784ae533b0ae5a /src/java/nginx/unit/websocket/PerMessageDeflate.java
parent3e23afb0d205e503f6cc7d852e34d07da9a5b7f7 (diff)
downloadunit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.gz
unit-2b8cab1e2478547398ad9c2fe68e025c180cac54.tar.bz2
Java: introducing websocket support.
Diffstat (limited to 'src/java/nginx/unit/websocket/PerMessageDeflate.java')
-rw-r--r--src/java/nginx/unit/websocket/PerMessageDeflate.java476
1 files changed, 476 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/PerMessageDeflate.java b/src/java/nginx/unit/websocket/PerMessageDeflate.java
new file mode 100644
index 00000000..88e0a0bc
--- /dev/null
+++ b/src/java/nginx/unit/websocket/PerMessageDeflate.java
@@ -0,0 +1,476 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.zip.DataFormatException;
+import java.util.zip.Deflater;
+import java.util.zip.Inflater;
+
+import javax.websocket.Extension;
+import javax.websocket.Extension.Parameter;
+import javax.websocket.SendHandler;
+
+import org.apache.tomcat.util.res.StringManager;
+
+public class PerMessageDeflate implements Transformation {
+
+ private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);
+
+ private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
+ private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
+ private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
+ private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
+
+ private static final int RSV_BITMASK = 0b100;
+ private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1};
+
+ public static final String NAME = "permessage-deflate";
+
+ private final boolean serverContextTakeover;
+ private final int serverMaxWindowBits;
+ private final boolean clientContextTakeover;
+ private final int clientMaxWindowBits;
+ private final boolean isServer;
+ private final Inflater inflater = new Inflater(true);
+ private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
+ private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];
+
+ private volatile Transformation next;
+ private volatile boolean skipDecompression = false;
+ private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private volatile boolean firstCompressedFrameWritten = false;
+ // Flag to track if a message is completely empty
+ private volatile boolean emptyMessage = true;
+
+ static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
+ // Accept the first preference that the endpoint is able to support
+ for (List<Parameter> preference : preferences) {
+ boolean ok = true;
+ boolean serverContextTakeover = true;
+ int serverMaxWindowBits = -1;
+ boolean clientContextTakeover = true;
+ int clientMaxWindowBits = -1;
+
+ for (Parameter param : preference) {
+ if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
+ if (serverContextTakeover) {
+ serverContextTakeover = false;
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ SERVER_NO_CONTEXT_TAKEOVER ));
+ }
+ } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
+ if (clientContextTakeover) {
+ clientContextTakeover = false;
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ CLIENT_NO_CONTEXT_TAKEOVER ));
+ }
+ } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) {
+ if (serverMaxWindowBits == -1) {
+ serverMaxWindowBits = Integer.parseInt(param.getValue());
+ if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) {
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.invalidWindowSize",
+ SERVER_MAX_WINDOW_BITS,
+ Integer.valueOf(serverMaxWindowBits)));
+ }
+ // Java SE API (as of Java 8) does not expose the API to
+ // control the Window size. It is effectively hard-coded
+ // to 15
+ if (isServer && serverMaxWindowBits != 15) {
+ ok = false;
+ break;
+ // Note server window size is not an issue for the
+ // client since the client will assume 15 and if the
+ // server uses a smaller window everything will
+ // still work
+ }
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ SERVER_MAX_WINDOW_BITS ));
+ }
+ } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) {
+ if (clientMaxWindowBits == -1) {
+ if (param.getValue() == null) {
+ // Hint to server that the client supports this
+ // option. Java SE API (as of Java 8) does not
+ // expose the API to control the Window size. It is
+ // effectively hard-coded to 15
+ clientMaxWindowBits = 15;
+ } else {
+ clientMaxWindowBits = Integer.parseInt(param.getValue());
+ if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) {
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.invalidWindowSize",
+ CLIENT_MAX_WINDOW_BITS,
+ Integer.valueOf(clientMaxWindowBits)));
+ }
+ }
+ // Java SE API (as of Java 8) does not expose the API to
+ // control the Window size. It is effectively hard-coded
+ // to 15
+ if (!isServer && clientMaxWindowBits != 15) {
+ ok = false;
+ break;
+ // Note client window size is not an issue for the
+ // server since the server will assume 15 and if the
+ // client uses a smaller window everything will
+ // still work
+ }
+ } else {
+ // Duplicate definition
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.duplicateParameter",
+ CLIENT_MAX_WINDOW_BITS ));
+ }
+ } else {
+ // Unknown parameter
+ throw new IllegalArgumentException(sm.getString(
+ "perMessageDeflate.unknownParameter", param.getName()));
+ }
+ }
+ if (ok) {
+ return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits,
+ clientContextTakeover, clientMaxWindowBits, isServer);
+ }
+ }
+ // Failed to negotiate agreeable terms
+ return null;
+ }
+
+
+ private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits,
+ boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) {
+ this.serverContextTakeover = serverContextTakeover;
+ this.serverMaxWindowBits = serverMaxWindowBits;
+ this.clientContextTakeover = clientContextTakeover;
+ this.clientMaxWindowBits = clientMaxWindowBits;
+ this.isServer = isServer;
+ }
+
+
+ @Override
+ public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
+ throws IOException {
+ // Control frames are never compressed and may appear in the middle of
+ // a WebSocket method. Pass them straight through.
+ if (Util.isControl(opCode)) {
+ return next.getMoreData(opCode, fin, rsv, dest);
+ }
+
+ if (!Util.isContinuation(opCode)) {
+ // First frame in new message
+ skipDecompression = (rsv & RSV_BITMASK) == 0;
+ }
+
+ // Pass uncompressed frames straight through.
+ if (skipDecompression) {
+ return next.getMoreData(opCode, fin, rsv, dest);
+ }
+
+ int written;
+ boolean usedEomBytes = false;
+
+ while (dest.remaining() > 0) {
+ // Space available in destination. Try and fill it.
+ try {
+ written = inflater.inflate(
+ dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
+ } catch (DataFormatException e) {
+ throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
+ }
+ dest.position(dest.position() + written);
+
+ if (inflater.needsInput() && !usedEomBytes ) {
+ if (dest.hasRemaining()) {
+ readBuffer.clear();
+ TransformationResult nextResult =
+ next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
+ inflater.setInput(
+ readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
+ if (TransformationResult.UNDERFLOW.equals(nextResult)) {
+ return nextResult;
+ } else if (TransformationResult.END_OF_FRAME.equals(nextResult) &&
+ readBuffer.position() == 0) {
+ if (fin) {
+ inflater.setInput(EOM_BYTES);
+ usedEomBytes = true;
+ } else {
+ return TransformationResult.END_OF_FRAME;
+ }
+ }
+ }
+ } else if (written == 0) {
+ if (fin && (isServer && !clientContextTakeover ||
+ !isServer && !serverContextTakeover)) {
+ inflater.reset();
+ }
+ return TransformationResult.END_OF_FRAME;
+ }
+ }
+
+ return TransformationResult.OVERFLOW;
+ }
+
+
+ @Override
+ public boolean validateRsv(int rsv, byte opCode) {
+ if (Util.isControl(opCode)) {
+ if ((rsv & RSV_BITMASK) != 0) {
+ return false;
+ } else {
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsv(rsv, opCode);
+ }
+ }
+ } else {
+ int rsvNext = rsv;
+ if ((rsv & RSV_BITMASK) != 0) {
+ rsvNext = rsv ^ RSV_BITMASK;
+ }
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsv(rsvNext, opCode);
+ }
+ }
+ }
+
+
+ @Override
+ public Extension getExtensionResponse() {
+ Extension result = new WsExtension(NAME);
+
+ List<Extension.Parameter> params = result.getParameters();
+
+ if (!serverContextTakeover) {
+ params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
+ }
+ if (serverMaxWindowBits != -1) {
+ params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS,
+ Integer.toString(serverMaxWindowBits)));
+ }
+ if (!clientContextTakeover) {
+ params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
+ }
+ if (clientMaxWindowBits != -1) {
+ params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS,
+ Integer.toString(clientMaxWindowBits)));
+ }
+
+ return result;
+ }
+
+
+ @Override
+ public void setNext(Transformation t) {
+ if (next == null) {
+ this.next = t;
+ } else {
+ next.setNext(t);
+ }
+ }
+
+
+ @Override
+ public boolean validateRsvBits(int i) {
+ if ((i & RSV_BITMASK) != 0) {
+ return false;
+ }
+ if (next == null) {
+ return true;
+ } else {
+ return next.validateRsvBits(i | RSV_BITMASK);
+ }
+ }
+
+
+ @Override
+ public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) {
+ List<MessagePart> allCompressedParts = new ArrayList<>();
+
+ for (MessagePart uncompressedPart : uncompressedParts) {
+ byte opCode = uncompressedPart.getOpCode();
+ boolean emptyPart = uncompressedPart.getPayload().limit() == 0;
+ emptyMessage = emptyMessage && emptyPart;
+ if (Util.isControl(opCode)) {
+ // Control messages can appear in the middle of other messages
+ // and must not be compressed. Pass it straight through
+ allCompressedParts.add(uncompressedPart);
+ } else if (emptyMessage && uncompressedPart.isFin()) {
+ // Zero length messages can't be compressed so pass the
+ // final (empty) part straight through.
+ allCompressedParts.add(uncompressedPart);
+ } else {
+ List<MessagePart> compressedParts = new ArrayList<>();
+ ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
+ SendHandler uncompressedIntermediateHandler =
+ uncompressedPart.getIntermediateHandler();
+
+ deflater.setInput(uncompressedPayload.array(),
+ uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
+ uncompressedPayload.remaining());
+
+ int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
+ boolean deflateRequired = true;
+
+ while (deflateRequired) {
+ ByteBuffer compressedPayload = writeBuffer;
+
+ int written = deflater.deflate(compressedPayload.array(),
+ compressedPayload.arrayOffset() + compressedPayload.position(),
+ compressedPayload.remaining(), flush);
+ compressedPayload.position(compressedPayload.position() + written);
+
+ if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
+ // This message part has been fully processed by the
+ // deflater. Fire the send handler for this message part
+ // and move on to the next message part.
+ break;
+ }
+
+ // If this point is reached, a new compressed message part
+ // will be created...
+ MessagePart compressedPart;
+
+ // .. and a new writeBuffer will be required.
+ writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+
+ // Flip the compressed payload ready for writing
+ compressedPayload.flip();
+
+ boolean fin = uncompressedPart.isFin();
+ boolean full = compressedPayload.limit() == compressedPayload.capacity();
+ boolean needsInput = deflater.needsInput();
+ long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();
+
+ if (fin && !full && needsInput) {
+ // End of compressed message. Drop EOM bytes and output.
+ compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
+ compressedPart = new MessagePart(true, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ startNewMessage();
+ } else if (full && !needsInput) {
+ // Write buffer full and input message not fully read.
+ // Output and start new compressed part.
+ compressedPart = new MessagePart(false, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ } else if (!fin && full && needsInput) {
+ // Write buffer full and input message not fully read.
+ // Output and get more data.
+ compressedPart = new MessagePart(false, getRsv(uncompressedPart),
+ opCode, compressedPayload, uncompressedIntermediateHandler,
+ uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ } else if (fin && full && needsInput) {
+ // Write buffer full. Input fully read. Deflater may be
+ // in one of four states:
+ // - output complete (just happened to align with end of
+ // buffer
+ // - in middle of EOM bytes
+ // - about to write EOM bytes
+ // - more data to write
+ int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
+ if (eomBufferWritten < EOM_BUFFER.length) {
+ // EOM has just been completed
+ compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
+ compressedPart = new MessagePart(true,
+ getRsv(uncompressedPart), opCode, compressedPayload,
+ uncompressedIntermediateHandler, uncompressedIntermediateHandler,
+ blockingWriteTimeoutExpiry);
+ deflateRequired = false;
+ startNewMessage();
+ } else {
+ // More data to write
+ // Copy bytes to new write buffer
+ writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
+ compressedPart = new MessagePart(false,
+ getRsv(uncompressedPart), opCode, compressedPayload,
+ uncompressedIntermediateHandler, uncompressedIntermediateHandler,
+ blockingWriteTimeoutExpiry);
+ }
+ } else {
+ throw new IllegalStateException("Should never happen");
+ }
+
+ // Add the newly created compressed part to the set of parts
+ // to pass on to the next transformation.
+ compressedParts.add(compressedPart);
+ }
+
+ SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
+ int size = compressedParts.size();
+ if (size > 0) {
+ compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
+ }
+
+ allCompressedParts.addAll(compressedParts);
+ }
+ }
+
+ if (next == null) {
+ return allCompressedParts;
+ } else {
+ return next.sendMessagePart(allCompressedParts);
+ }
+ }
+
+
+ private void startNewMessage() {
+ firstCompressedFrameWritten = false;
+ emptyMessage = true;
+ if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) {
+ deflater.reset();
+ }
+ }
+
+
+ private int getRsv(MessagePart uncompressedMessagePart) {
+ int result = uncompressedMessagePart.getRsv();
+ if (!firstCompressedFrameWritten) {
+ result += RSV_BITMASK;
+ firstCompressedFrameWritten = true;
+ }
+ return result;
+ }
+
+
+ @Override
+ public void close() {
+ // There will always be a next transformation
+ next.close();
+ inflater.end();
+ deflater.end();
+ }
+}