/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package nginx.unit.websocket;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import javax.websocket.Extension;
import javax.websocket.Extension.Parameter;
import javax.websocket.SendHandler;
import org.apache.tomcat.util.res.StringManager;
public class PerMessageDeflate implements Transformation {
private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class);
private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover";
private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover";
private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits";
private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits";
private static final int RSV_BITMASK = 0b100;
private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1};
public static final String NAME = "permessage-deflate";
private final boolean serverContextTakeover;
private final int serverMaxWindowBits;
private final boolean clientContextTakeover;
private final int clientMaxWindowBits;
private final boolean isServer;
private final Inflater inflater = new Inflater(true);
private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true);
private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1];
private volatile Transformation next;
private volatile boolean skipDecompression = false;
private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
private volatile boolean firstCompressedFrameWritten = false;
// Flag to track if a message is completely empty
private volatile boolean emptyMessage = true;
static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) {
// Accept the first preference that the endpoint is able to support
for (List<Parameter> preference : preferences) {
boolean ok = true;
boolean serverContextTakeover = true;
int serverMaxWindowBits = -1;
boolean clientContextTakeover = true;
int clientMaxWindowBits = -1;
for (Parameter param : preference) {
if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
if (serverContextTakeover) {
serverContextTakeover = false;
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
SERVER_NO_CONTEXT_TAKEOVER ));
}
} else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) {
if (clientContextTakeover) {
clientContextTakeover = false;
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
CLIENT_NO_CONTEXT_TAKEOVER ));
}
} else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) {
if (serverMaxWindowBits == -1) {
serverMaxWindowBits = Integer.parseInt(param.getValue());
if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) {
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.invalidWindowSize",
SERVER_MAX_WINDOW_BITS,
Integer.valueOf(serverMaxWindowBits)));
}
// Java SE API (as of Java 8) does not expose the API to
// control the Window size. It is effectively hard-coded
// to 15
if (isServer && serverMaxWindowBits != 15) {
ok = false;
break;
// Note server window size is not an issue for the
// client since the client will assume 15 and if the
// server uses a smaller window everything will
// still work
}
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
SERVER_MAX_WINDOW_BITS ));
}
} else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) {
if (clientMaxWindowBits == -1) {
if (param.getValue() == null) {
// Hint to server that the client supports this
// option. Java SE API (as of Java 8) does not
// expose the API to control the Window size. It is
// effectively hard-coded to 15
clientMaxWindowBits = 15;
} else {
clientMaxWindowBits = Integer.parseInt(param.getValue());
if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) {
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.invalidWindowSize",
CLIENT_MAX_WINDOW_BITS,
Integer.valueOf(clientMaxWindowBits)));
}
}
// Java SE API (as of Java 8) does not expose the API to
// control the Window size. It is effectively hard-coded
// to 15
if (!isServer && clientMaxWindowBits != 15) {
ok = false;
break;
// Note client window size is not an issue for the
// server since the server will assume 15 and if the
// client uses a smaller window everything will
// still work
}
} else {
// Duplicate definition
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.duplicateParameter",
CLIENT_MAX_WINDOW_BITS ));
}
} else {
// Unknown parameter
throw new IllegalArgumentException(sm.getString(
"perMessageDeflate.unknownParameter", param.getName()));
}
}
if (ok) {
return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits,
clientContextTakeover, clientMaxWindowBits, isServer);
}
}
// Failed to negotiate agreeable terms
return null;
}
private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits,
boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) {
this.serverContextTakeover = serverContextTakeover;
this.serverMaxWindowBits = serverMaxWindowBits;
this.clientContextTakeover = clientContextTakeover;
this.clientMaxWindowBits = clientMaxWindowBits;
this.isServer = isServer;
}
@Override
public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest)
throws IOException {
// Control frames are never compressed and may appear in the middle of
// a WebSocket method. Pass them straight through.
if (Util.isControl(opCode)) {
return next.getMoreData(opCode, fin, rsv, dest);
}
if (!Util.isContinuation(opCode)) {
// First frame in new message
skipDecompression = (rsv & RSV_BITMASK) == 0;
}
// Pass uncompressed frames straight through.
if (skipDecompression) {
return next.getMoreData(opCode, fin, rsv, dest);
}
int written;
boolean usedEomBytes = false;
while (dest.remaining() > 0) {
// Space available in destination. Try and fill it.
try {
written = inflater.inflate(
dest.array(), dest.arrayOffset() + dest.position(), dest.remaining());
} catch (DataFormatException e) {
throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e);
}
dest.position(dest.position() + written);
if (inflater.needsInput() && !usedEomBytes ) {
if (dest.hasRemaining()) {
readBuffer.clear();
TransformationResult nextResult =
next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer);
inflater.setInput(
readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position());
if (TransformationResult.UNDERFLOW.equals(nextResult)) {
return nextResult;
} else if (TransformationResult.END_OF_FRAME.equals(nextResult) &&
readBuffer.position() == 0) {
if (fin) {
inflater.setInput(EOM_BYTES);
usedEomBytes = true;
} else {
return TransformationResult.END_OF_FRAME;
}
}
}
} else if (written == 0) {
if (fin && (isServer && !clientContextTakeover ||
!isServer && !serverContextTakeover)) {
inflater.reset();
}
return TransformationResult.END_OF_FRAME;
}
}
return TransformationResult.OVERFLOW;
}
@Override
public boolean validateRsv(int rsv, byte opCode) {
if (Util.isControl(opCode)) {
if ((rsv & RSV_BITMASK) != 0) {
return false;
} else {
if (next == null) {
return true;
} else {
return next.validateRsv(rsv, opCode);
}
}
} else {
int rsvNext = rsv;
if ((rsv & RSV_BITMASK) != 0) {
rsvNext = rsv ^ RSV_BITMASK;
}
if (next == null) {
return true;
} else {
return next.validateRsv(rsvNext, opCode);
}
}
}
@Override
public Extension getExtensionResponse() {
Extension result = new WsExtension(NAME);
List<Extension.Parameter> params = result.getParameters();
if (!serverContextTakeover) {
params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null));
}
if (serverMaxWindowBits != -1) {
params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS,
Integer.toString(serverMaxWindowBits)));
}
if (!clientContextTakeover) {
params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null));
}
if (clientMaxWindowBits != -1) {
params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS,
Integer.toString(clientMaxWindowBits)));
}
return result;
}
@Override
public void setNext(Transformation t) {
if (next == null) {
this.next = t;
} else {
next.setNext(t);
}
}
@Override
public boolean validateRsvBits(int i) {
if ((i & RSV_BITMASK) != 0) {
return false;
}
if (next == null) {
return true;
} else {
return next.validateRsvBits(i | RSV_BITMASK);
}
}
@Override
public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) {
List<MessagePart> allCompressedParts = new ArrayList<>();
for (MessagePart uncompressedPart : uncompressedParts) {
byte opCode = uncompressedPart.getOpCode();
boolean emptyPart = uncompressedPart.getPayload().limit() == 0;
emptyMessage = emptyMessage && emptyPart;
if (Util.isControl(opCode)) {
// Control messages can appear in the middle of other messages
// and must not be compressed. Pass it straight through
allCompressedParts.add(uncompressedPart);
} else if (emptyMessage && uncompressedPart.isFin()) {
// Zero length messages can't be compressed so pass the
// final (empty) part straight through.
allCompressedParts.add(uncompressedPart);
} else {
List<MessagePart> compressedParts = new ArrayList<>();
ByteBuffer uncompressedPayload = uncompressedPart.getPayload();
SendHandler uncompressedIntermediateHandler =
uncompressedPart.getIntermediateHandler();
deflater.setInput(uncompressedPayload.array(),
uncompressedPayload.arrayOffset() + uncompressedPayload.position(),
uncompressedPayload.remaining());
int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH);
boolean deflateRequired = true;
while (deflateRequired) {
ByteBuffer compressedPayload = writeBuffer;
int written = deflater.deflate(compressedPayload.array(),
compressedPayload.arrayOffset() + compressedPayload.position(),
compressedPayload.remaining(), flush);
compressedPayload.position(compressedPayload.position() + written);
if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) {
// This message part has been fully processed by the
// deflater. Fire the send handler for this message part
// and move on to the next message part.
break;
}
// If this point is reached, a new compressed message part
// will be created...
MessagePart compressedPart;
// .. and a new writeBuffer will be required.
writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
// Flip the compressed payload ready for writing
compressedPayload.flip();
boolean fin = uncompressedPart.isFin();
boolean full = compressedPayload.limit() == compressedPayload.capacity();
boolean needsInput = deflater.needsInput();
long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry();
if (fin && !full && needsInput) {
// End of compressed message. Drop EOM bytes and output.
compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length);
compressedPart = new MessagePart(true, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
deflateRequired = false;
startNewMessage();
} else if (full && !needsInput) {
// Write buffer full and input message not fully read.
// Output and start new compressed part.
compressedPart = new MessagePart(false, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
} else if (!fin && full && needsInput) {
// Write buffer full and input message not fully read.
// Output and get more data.
compressedPart = new MessagePart(false, getRsv(uncompressedPart),
opCode, compressedPayload, uncompressedIntermediateHandler,
uncompressedIntermediateHandler, blockingWriteTimeoutExpiry);
deflateRequired = false;
} else if (fin && full && needsInput) {
// Write buffer full. Input fully read. Deflater may be
// in one of four states:
// - output complete (just happened to align with end of
// buffer
// - in middle of EOM bytes
// - about to write EOM bytes
// - more data to write
int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH);
if (eomBufferWritten < EOM_BUFFER.length) {
// EOM has just been completed
compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten);
compressedPart = new MessagePart(true,
getRsv(uncompressedPart), opCode, compressedPayload,
uncompressedIntermediateHandler, uncompressedIntermediateHandler,
blockingWriteTimeoutExpiry);
deflateRequired = false;
startNewMessage();
} else {
// More data to write
// Copy bytes to new write buffer
writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten);
compressedPart = new MessagePart(false,
getRsv(uncompressedPart), opCode, compressedPayload,
uncompressedIntermediateHandler, uncompressedIntermediateHandler,
blockingWriteTimeoutExpiry);
}
} else {
throw new IllegalStateException("Should never happen");
}
// Add the newly created compressed part to the set of parts
// to pass on to the next transformation.
compressedParts.add(compressedPart);
}
SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler();
int size = compressedParts.size();
if (size > 0) {
compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler);
}
allCompressedParts.addAll(compressedParts);
}
}
if (next == null) {
return allCompressedParts;
} else {
return next.sendMessagePart(allCompressedParts);
}
}
private void startNewMessage() {
firstCompressedFrameWritten = false;
emptyMessage = true;
if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) {
deflater.reset();
}
}
private int getRsv(MessagePart uncompressedMessagePart) {
int result = uncompressedMessagePart.getRsv();
if (!firstCompressedFrameWritten) {
result += RSV_BITMASK;
firstCompressedFrameWritten = true;
}
return result;
}
@Override
public void close() {
// There will always be a next transformation
next.close();
inflater.end();
deflater.end();
}
}