summaryrefslogtreecommitdiffhomepage
path: root/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java')
-rw-r--r--src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java1234
1 files changed, 1234 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java
new file mode 100644
index 00000000..776124fd
--- /dev/null
+++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java
@@ -0,0 +1,1234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.io.Writer;
+import java.net.SocketTimeoutException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.charset.CharsetEncoder;
+import java.nio.charset.CoderResult;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.Future;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import javax.websocket.CloseReason;
+import javax.websocket.CloseReason.CloseCodes;
+import javax.websocket.DeploymentException;
+import javax.websocket.EncodeException;
+import javax.websocket.Encoder;
+import javax.websocket.EndpointConfig;
+import javax.websocket.RemoteEndpoint;
+import javax.websocket.SendHandler;
+import javax.websocket.SendResult;
+
+import org.apache.juli.logging.Log;
+import org.apache.juli.logging.LogFactory;
+import org.apache.tomcat.util.buf.Utf8Encoder;
+import org.apache.tomcat.util.res.StringManager;
+
+import nginx.unit.Request;
+
+public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
+
+ private static final StringManager sm =
+ StringManager.getManager(WsRemoteEndpointImplBase.class);
+
+ protected static final SendResult SENDRESULT_OK = new SendResult();
+
+ private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static
+
+ private final StateMachine stateMachine = new StateMachine();
+
+ private final IntermediateMessageHandler intermediateMessageHandler =
+ new IntermediateMessageHandler(this);
+
+ private Transformation transformation = null;
+ private final Semaphore messagePartInProgress = new Semaphore(1);
+ private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>();
+ private final Object messagePartLock = new Object();
+
+ // State
+ private volatile boolean closed = false;
+ private boolean fragmented = false;
+ private boolean nextFragmented = false;
+ private boolean text = false;
+ private boolean nextText = false;
+
+ // Max size of WebSocket header is 14 bytes
+ private final ByteBuffer headerBuffer = ByteBuffer.allocate(14);
+ private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final CharsetEncoder encoder = new Utf8Encoder();
+ private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final AtomicBoolean batchingAllowed = new AtomicBoolean(false);
+ private volatile long sendTimeout = -1;
+ private WsSession wsSession;
+ private List<EncoderEntry> encoderEntries = new ArrayList<>();
+
+ private Request request;
+
+
+ protected void setTransformation(Transformation transformation) {
+ this.transformation = transformation;
+ }
+
+
+ public long getSendTimeout() {
+ return sendTimeout;
+ }
+
+
+ public void setSendTimeout(long timeout) {
+ this.sendTimeout = timeout;
+ }
+
+
+ @Override
+ public void setBatchingAllowed(boolean batchingAllowed) throws IOException {
+ boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed);
+
+ if (oldValue && !batchingAllowed) {
+ flushBatch();
+ }
+ }
+
+
+ @Override
+ public boolean getBatchingAllowed() {
+ return batchingAllowed.get();
+ }
+
+
+ @Override
+ public void flushBatch() throws IOException {
+ sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true);
+ }
+
+
+ public void sendBytes(ByteBuffer data) throws IOException {
+ if (data == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.binaryStart();
+ sendMessageBlock(Constants.OPCODE_BINARY, data, true);
+ stateMachine.complete(true);
+ }
+
+
+ public Future<Void> sendBytesByFuture(ByteBuffer data) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendBytesByCompletion(data, f2sh);
+ return f2sh;
+ }
+
+
+ public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) {
+ if (data == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (handler == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+ StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine);
+ stateMachine.binaryStart();
+ startMessage(Constants.OPCODE_BINARY, data, true, sush);
+ }
+
+
+ public void sendPartialBytes(ByteBuffer partialByte, boolean last)
+ throws IOException {
+ if (partialByte == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.binaryPartialStart();
+ sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last);
+ stateMachine.complete(last);
+ }
+
+
+ @Override
+ public void sendPing(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ if (applicationData.remaining() > 125) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
+ }
+ sendMessageBlock(Constants.OPCODE_PING, applicationData, true);
+ }
+
+
+ @Override
+ public void sendPong(ByteBuffer applicationData) throws IOException,
+ IllegalArgumentException {
+ if (applicationData.remaining() > 125) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData"));
+ }
+ sendMessageBlock(Constants.OPCODE_PONG, applicationData, true);
+ }
+
+
+ public void sendString(String text) throws IOException {
+ if (text == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.textStart();
+ sendMessageBlock(CharBuffer.wrap(text), true);
+ }
+
+
+ public Future<Void> sendStringByFuture(String text) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendStringByCompletion(text, f2sh);
+ return f2sh;
+ }
+
+
+ public void sendStringByCompletion(String text, SendHandler handler) {
+ if (text == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (handler == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+ stateMachine.textStart();
+ TextMessageSendHandler tmsh = new TextMessageSendHandler(handler,
+ CharBuffer.wrap(text), true, encoder, encoderBuffer, this);
+ tmsh.write();
+ // TextMessageSendHandler will update stateMachine when it completes
+ }
+
+
+ public void sendPartialString(String fragment, boolean isLast)
+ throws IOException {
+ if (fragment == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ stateMachine.textPartialStart();
+ sendMessageBlock(CharBuffer.wrap(fragment), isLast);
+ }
+
+
+ public OutputStream getSendStream() {
+ stateMachine.streamStart();
+ return new WsOutputStream(this);
+ }
+
+
+ public Writer getSendWriter() {
+ stateMachine.writeStart();
+ return new WsWriter(this);
+ }
+
+
+ void sendMessageBlock(CharBuffer part, boolean last) throws IOException {
+ long timeoutExpiry = getTimeoutExpiry();
+ boolean isDone = false;
+ while (!isDone) {
+ encoderBuffer.clear();
+ CoderResult cr = encoder.encode(part, encoderBuffer, true);
+ if (cr.isError()) {
+ throw new IllegalArgumentException(cr.toString());
+ }
+ isDone = !cr.isOverflow();
+ encoderBuffer.flip();
+ sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry);
+ }
+ stateMachine.complete(last);
+ }
+
+
+ void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last)
+ throws IOException {
+ sendMessageBlock(opCode, payload, last, getTimeoutExpiry());
+ }
+
+
+ private long getTimeoutExpiry() {
+ // Get the timeout before we send the message. The message may
+ // trigger a session close and depending on timing the client
+ // session may close before we can read the timeout.
+ long timeout = getBlockingSendTimeout();
+ if (timeout < 0) {
+ return Long.MAX_VALUE;
+ } else {
+ return System.currentTimeMillis() + timeout;
+ }
+ }
+
+ private byte currentOpCode = Constants.OPCODE_CONTINUATION;
+
+ private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last,
+ long timeoutExpiry) throws IOException {
+ wsSession.updateLastActive();
+
+ if (opCode == currentOpCode) {
+ opCode = Constants.OPCODE_CONTINUATION;
+ }
+
+ request.sendWsFrame(payload, opCode, last, timeoutExpiry);
+
+ if (!last && opCode != Constants.OPCODE_CONTINUATION) {
+ currentOpCode = opCode;
+ }
+
+ if (last && opCode == Constants.OPCODE_CONTINUATION) {
+ currentOpCode = Constants.OPCODE_CONTINUATION;
+ }
+ }
+
+
+ void startMessage(byte opCode, ByteBuffer payload, boolean last,
+ SendHandler handler) {
+
+ wsSession.updateLastActive();
+
+ List<MessagePart> messageParts = new ArrayList<>();
+ messageParts.add(new MessagePart(last, 0, opCode, payload,
+ intermediateMessageHandler,
+ new EndMessageHandler(this, handler), -1));
+
+ messageParts = transformation.sendMessagePart(messageParts);
+
+ // Some extensions/transformations may buffer messages so it is possible
+ // that no message parts will be returned. If this is the case the
+ // trigger the supplied SendHandler
+ if (messageParts.size() == 0) {
+ handler.onResult(new SendResult());
+ return;
+ }
+
+ MessagePart mp = messageParts.remove(0);
+
+ boolean doWrite = false;
+ synchronized (messagePartLock) {
+ if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) {
+ // Should not happen. To late to send batched messages now since
+ // the session has been closed. Complain loudly.
+ log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed"));
+ }
+ if (messagePartInProgress.tryAcquire()) {
+ doWrite = true;
+ } else {
+ // When a control message is sent while another message is being
+ // sent, the control message is queued. Chances are the
+ // subsequent data message part will end up queued while the
+ // control message is sent. The logic in this class (state
+ // machine, EndMessageHandler, TextMessageSendHandler) ensures
+ // that there will only ever be one data message part in the
+ // queue. There could be multiple control messages in the queue.
+
+ // Add it to the queue
+ messagePartQueue.add(mp);
+ }
+ // Add any remaining messages to the queue
+ messagePartQueue.addAll(messageParts);
+ }
+ if (doWrite) {
+ // Actual write has to be outside sync block to avoid possible
+ // deadlock between messagePartLock and writeLock in
+ // o.a.coyote.http11.upgrade.AbstractServletOutputStream
+ writeMessagePart(mp);
+ }
+ }
+
+
+ void endMessage(SendHandler handler, SendResult result) {
+ boolean doWrite = false;
+ MessagePart mpNext = null;
+ synchronized (messagePartLock) {
+
+ fragmented = nextFragmented;
+ text = nextText;
+
+ mpNext = messagePartQueue.poll();
+ if (mpNext == null) {
+ messagePartInProgress.release();
+ } else if (!closed){
+ // Session may have been closed unexpectedly in the middle of
+ // sending a fragmented message closing the endpoint. If this
+ // happens, clearly there is no point trying to send the rest of
+ // the message.
+ doWrite = true;
+ }
+ }
+ if (doWrite) {
+ // Actual write has to be outside sync block to avoid possible
+ // deadlock between messagePartLock and writeLock in
+ // o.a.coyote.http11.upgrade.AbstractServletOutputStream
+ writeMessagePart(mpNext);
+ }
+
+ wsSession.updateLastActive();
+
+ // Some handlers, such as the IntermediateMessageHandler, do not have a
+ // nested handler so handler may be null.
+ if (handler != null) {
+ handler.onResult(result);
+ }
+ }
+
+
+ void writeMessagePart(MessagePart mp) {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closed"));
+ }
+
+ if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) {
+ nextFragmented = fragmented;
+ nextText = text;
+ outputBuffer.flip();
+ SendHandler flushHandler = new OutputBufferFlushSendHandler(
+ outputBuffer, mp.getEndHandler());
+ doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer);
+ return;
+ }
+
+ // Control messages may be sent in the middle of fragmented message
+ // so they have no effect on the fragmented or text flags
+ boolean first;
+ if (Util.isControl(mp.getOpCode())) {
+ nextFragmented = fragmented;
+ nextText = text;
+ if (mp.getOpCode() == Constants.OPCODE_CLOSE) {
+ closed = true;
+ }
+ first = true;
+ } else {
+ boolean isText = Util.isText(mp.getOpCode());
+
+ if (fragmented) {
+ // Currently fragmented
+ if (text != isText) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.changeType"));
+ }
+ nextText = text;
+ nextFragmented = !mp.isFin();
+ first = false;
+ } else {
+ // Wasn't fragmented. Might be now
+ if (mp.isFin()) {
+ nextFragmented = false;
+ } else {
+ nextFragmented = true;
+ nextText = isText;
+ }
+ first = true;
+ }
+ }
+
+ byte[] mask;
+
+ if (isMasked()) {
+ mask = Util.generateMask();
+ } else {
+ mask = null;
+ }
+
+ headerBuffer.clear();
+ writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(),
+ isMasked(), mp.getPayload(), mask, first);
+ headerBuffer.flip();
+
+ if (getBatchingAllowed() || isMasked()) {
+ // Need to write via output buffer
+ OutputBufferSendHandler obsh = new OutputBufferSendHandler(
+ mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
+ headerBuffer, mp.getPayload(), mask,
+ outputBuffer, !getBatchingAllowed(), this);
+ obsh.write();
+ } else {
+ // Can write directly
+ doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(),
+ headerBuffer, mp.getPayload());
+ }
+ }
+
+
+ private long getBlockingSendTimeout() {
+ Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY);
+ Long userTimeout = null;
+ if (obj instanceof Long) {
+ userTimeout = (Long) obj;
+ }
+ if (userTimeout == null) {
+ return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT;
+ } else {
+ return userTimeout.longValue();
+ }
+ }
+
+
+ /**
+ * Wraps the user provided handler so that the end point is notified when
+ * the message is complete.
+ */
+ private static class EndMessageHandler implements SendHandler {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final SendHandler handler;
+
+ public EndMessageHandler(WsRemoteEndpointImplBase endpoint,
+ SendHandler handler) {
+ this.endpoint = endpoint;
+ this.handler = handler;
+ }
+
+
+ @Override
+ public void onResult(SendResult result) {
+ endpoint.endMessage(handler, result);
+ }
+ }
+
+
+ /**
+ * If a transformation needs to split a {@link MessagePart} into multiple
+ * {@link MessagePart}s, it uses this handler as the end handler for each of
+ * the additional {@link MessagePart}s. This handler notifies this this
+ * class that the {@link MessagePart} has been processed and that the next
+ * {@link MessagePart} in the queue should be started. The final
+ * {@link MessagePart} will use the {@link EndMessageHandler} provided with
+ * the original {@link MessagePart}.
+ */
+ private static class IntermediateMessageHandler implements SendHandler {
+
+ private final WsRemoteEndpointImplBase endpoint;
+
+ public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+
+ @Override
+ public void onResult(SendResult result) {
+ endpoint.endMessage(null, result);
+ }
+ }
+
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public void sendObject(Object obj) throws IOException, EncodeException {
+ if (obj == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ /*
+ * Note that the implementation will convert primitives and their object
+ * equivalents by default but that users are free to specify their own
+ * encoders and decoders for this if they wish.
+ */
+ Encoder encoder = findEncoder(obj);
+ if (encoder == null && Util.isPrimitive(obj.getClass())) {
+ String msg = obj.toString();
+ sendString(msg);
+ return;
+ }
+ if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
+ ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
+ sendBytes(msg);
+ return;
+ }
+
+ if (encoder instanceof Encoder.Text) {
+ String msg = ((Encoder.Text) encoder).encode(obj);
+ sendString(msg);
+ } else if (encoder instanceof Encoder.TextStream) {
+ try (Writer w = getSendWriter()) {
+ ((Encoder.TextStream) encoder).encode(obj, w);
+ }
+ } else if (encoder instanceof Encoder.Binary) {
+ ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
+ sendBytes(msg);
+ } else if (encoder instanceof Encoder.BinaryStream) {
+ try (OutputStream os = getSendStream()) {
+ ((Encoder.BinaryStream) encoder).encode(obj, os);
+ }
+ } else {
+ throw new EncodeException(obj, sm.getString(
+ "wsRemoteEndpoint.noEncoder", obj.getClass()));
+ }
+ }
+
+
+ public Future<Void> sendObjectByFuture(Object obj) {
+ FutureToSendHandler f2sh = new FutureToSendHandler(wsSession);
+ sendObjectByCompletion(obj, f2sh);
+ return f2sh;
+ }
+
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public void sendObjectByCompletion(Object obj, SendHandler completion) {
+
+ if (obj == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData"));
+ }
+ if (completion == null) {
+ throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler"));
+ }
+
+ /*
+ * Note that the implementation will convert primitives and their object
+ * equivalents by default but that users are free to specify their own
+ * encoders and decoders for this if they wish.
+ */
+ Encoder encoder = findEncoder(obj);
+ if (encoder == null && Util.isPrimitive(obj.getClass())) {
+ String msg = obj.toString();
+ sendStringByCompletion(msg, completion);
+ return;
+ }
+ if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) {
+ ByteBuffer msg = ByteBuffer.wrap((byte[]) obj);
+ sendBytesByCompletion(msg, completion);
+ return;
+ }
+
+ try {
+ if (encoder instanceof Encoder.Text) {
+ String msg = ((Encoder.Text) encoder).encode(obj);
+ sendStringByCompletion(msg, completion);
+ } else if (encoder instanceof Encoder.TextStream) {
+ try (Writer w = getSendWriter()) {
+ ((Encoder.TextStream) encoder).encode(obj, w);
+ }
+ completion.onResult(new SendResult());
+ } else if (encoder instanceof Encoder.Binary) {
+ ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj);
+ sendBytesByCompletion(msg, completion);
+ } else if (encoder instanceof Encoder.BinaryStream) {
+ try (OutputStream os = getSendStream()) {
+ ((Encoder.BinaryStream) encoder).encode(obj, os);
+ }
+ completion.onResult(new SendResult());
+ } else {
+ throw new EncodeException(obj, sm.getString(
+ "wsRemoteEndpoint.noEncoder", obj.getClass()));
+ }
+ } catch (Exception e) {
+ SendResult sr = new SendResult(e);
+ completion.onResult(sr);
+ }
+ }
+
+
+ protected void setSession(WsSession wsSession) {
+ this.wsSession = wsSession;
+ }
+
+
+ protected void setRequest(Request request) {
+ this.request = request;
+ }
+
+ protected void setEncoders(EndpointConfig endpointConfig)
+ throws DeploymentException {
+ encoderEntries.clear();
+ for (Class<? extends Encoder> encoderClazz :
+ endpointConfig.getEncoders()) {
+ Encoder instance;
+ try {
+ instance = encoderClazz.getConstructor().newInstance();
+ instance.init(endpointConfig);
+ } catch (ReflectiveOperationException e) {
+ throw new DeploymentException(
+ sm.getString("wsRemoteEndpoint.invalidEncoder",
+ encoderClazz.getName()), e);
+ }
+ EncoderEntry entry = new EncoderEntry(
+ Util.getEncoderType(encoderClazz), instance);
+ encoderEntries.add(entry);
+ }
+ }
+
+
+ private Encoder findEncoder(Object obj) {
+ for (EncoderEntry entry : encoderEntries) {
+ if (entry.getClazz().isAssignableFrom(obj.getClass())) {
+ return entry.getEncoder();
+ }
+ }
+ return null;
+ }
+
+
+ public final void close() {
+ for (EncoderEntry entry : encoderEntries) {
+ entry.getEncoder().destroy();
+ }
+
+ request.closeWs();
+ }
+
+
+ protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry,
+ ByteBuffer... data);
+ protected abstract boolean isMasked();
+ protected abstract void doClose();
+
+ private static void writeHeader(ByteBuffer headerBuffer, boolean fin,
+ int rsv, byte opCode, boolean masked, ByteBuffer payload,
+ byte[] mask, boolean first) {
+
+ byte b = 0;
+
+ if (fin) {
+ // Set the fin bit
+ b -= 128;
+ }
+
+ b += (rsv << 4);
+
+ if (first) {
+ // This is the first fragment of this message
+ b += opCode;
+ }
+ // If not the first fragment, it is a continuation with opCode of zero
+
+ headerBuffer.put(b);
+
+ if (masked) {
+ b = (byte) 0x80;
+ } else {
+ b = 0;
+ }
+
+ // Next write the mask && length length
+ if (payload.limit() < 126) {
+ headerBuffer.put((byte) (payload.limit() | b));
+ } else if (payload.limit() < 65536) {
+ headerBuffer.put((byte) (126 | b));
+ headerBuffer.put((byte) (payload.limit() >>> 8));
+ headerBuffer.put((byte) (payload.limit() & 0xFF));
+ } else {
+ // Will never be more than 2^31-1
+ headerBuffer.put((byte) (127 | b));
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) 0);
+ headerBuffer.put((byte) (payload.limit() >>> 24));
+ headerBuffer.put((byte) (payload.limit() >>> 16));
+ headerBuffer.put((byte) (payload.limit() >>> 8));
+ headerBuffer.put((byte) (payload.limit() & 0xFF));
+ }
+ if (masked) {
+ headerBuffer.put(mask[0]);
+ headerBuffer.put(mask[1]);
+ headerBuffer.put(mask[2]);
+ headerBuffer.put(mask[3]);
+ }
+ }
+
+
+ private class TextMessageSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final CharBuffer message;
+ private final boolean isLast;
+ private final CharsetEncoder encoder;
+ private final ByteBuffer buffer;
+ private final WsRemoteEndpointImplBase endpoint;
+ private volatile boolean isDone = false;
+
+ public TextMessageSendHandler(SendHandler handler, CharBuffer message,
+ boolean isLast, CharsetEncoder encoder,
+ ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) {
+ this.handler = handler;
+ this.message = message;
+ this.isLast = isLast;
+ this.encoder = encoder.reset();
+ this.buffer = encoderBuffer;
+ this.endpoint = endpoint;
+ }
+
+ public void write() {
+ buffer.clear();
+ CoderResult cr = encoder.encode(message, buffer, true);
+ if (cr.isError()) {
+ throw new IllegalArgumentException(cr.toString());
+ }
+ isDone = !cr.isOverflow();
+ buffer.flip();
+ endpoint.startMessage(Constants.OPCODE_TEXT, buffer,
+ isDone && isLast, this);
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (isDone) {
+ endpoint.stateMachine.complete(isLast);
+ handler.onResult(result);
+ } else if(!result.isOK()) {
+ handler.onResult(result);
+ } else if (closed){
+ SendResult sr = new SendResult(new IOException(
+ sm.getString("wsRemoteEndpoint.closedDuringMessage")));
+ handler.onResult(sr);
+ } else {
+ write();
+ }
+ }
+ }
+
+
+ /**
+ * Used to write data to the output buffer, flushing the buffer if it fills
+ * up.
+ */
+ private static class OutputBufferSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final long blockingWriteTimeoutExpiry;
+ private final ByteBuffer headerBuffer;
+ private final ByteBuffer payload;
+ private final byte[] mask;
+ private final ByteBuffer outputBuffer;
+ private final boolean flushRequired;
+ private final WsRemoteEndpointImplBase endpoint;
+ private int maskIndex = 0;
+
+ public OutputBufferSendHandler(SendHandler completion,
+ long blockingWriteTimeoutExpiry,
+ ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask,
+ ByteBuffer outputBuffer, boolean flushRequired,
+ WsRemoteEndpointImplBase endpoint) {
+ this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry;
+ this.handler = completion;
+ this.headerBuffer = headerBuffer;
+ this.payload = payload;
+ this.mask = mask;
+ this.outputBuffer = outputBuffer;
+ this.flushRequired = flushRequired;
+ this.endpoint = endpoint;
+ }
+
+ public void write() {
+ // Write the header
+ while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) {
+ outputBuffer.put(headerBuffer.get());
+ }
+ if (headerBuffer.hasRemaining()) {
+ // Still more headers to write, need to flush
+ outputBuffer.flip();
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ return;
+ }
+
+ // Write the payload
+ int payloadLeft = payload.remaining();
+ int payloadLimit = payload.limit();
+ int outputSpace = outputBuffer.remaining();
+ int toWrite = payloadLeft;
+
+ if (payloadLeft > outputSpace) {
+ toWrite = outputSpace;
+ // Temporarily reduce the limit
+ payload.limit(payload.position() + toWrite);
+ }
+
+ if (mask == null) {
+ // Use a bulk copy
+ outputBuffer.put(payload);
+ } else {
+ for (int i = 0; i < toWrite; i++) {
+ outputBuffer.put(
+ (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF)));
+ if (maskIndex > 3) {
+ maskIndex = 0;
+ }
+ }
+ }
+
+ if (payloadLeft > outputSpace) {
+ // Restore the original limit
+ payload.limit(payloadLimit);
+ // Still more data to write, need to flush
+ outputBuffer.flip();
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ return;
+ }
+
+ if (flushRequired) {
+ outputBuffer.flip();
+ if (outputBuffer.remaining() == 0) {
+ handler.onResult(SENDRESULT_OK);
+ } else {
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ }
+ } else {
+ handler.onResult(SENDRESULT_OK);
+ }
+ }
+
+ // ------------------------------------------------- SendHandler methods
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ if (outputBuffer.hasRemaining()) {
+ endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer);
+ } else {
+ outputBuffer.clear();
+ write();
+ }
+ } else {
+ handler.onResult(result);
+ }
+ }
+ }
+
+
+ /**
+ * Ensures that the output buffer is cleared after it has been flushed.
+ */
+ private static class OutputBufferFlushSendHandler implements SendHandler {
+
+ private final ByteBuffer outputBuffer;
+ private final SendHandler handler;
+
+ public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) {
+ this.outputBuffer = outputBuffer;
+ this.handler = handler;
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ outputBuffer.clear();
+ }
+ handler.onResult(result);
+ }
+ }
+
+
+ private static class WsOutputStream extends OutputStream {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Object closeLock = new Object();
+ private volatile boolean closed = false;
+ private volatile boolean used = false;
+
+ public WsOutputStream(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ @Override
+ public void write(int b) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ buffer.put((byte) b);
+ }
+
+ @Override
+ public void write(byte[] b, int off, int len) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+ if (len == 0) {
+ return;
+ }
+ if ((off < 0) || (off > b.length) || (len < 0) ||
+ ((off + len) > b.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ int remaining = buffer.remaining();
+ int written = 0;
+
+ while (remaining < len - written) {
+ buffer.put(b, off + written, remaining);
+ written += remaining;
+ flush();
+ remaining = buffer.remaining();
+ }
+ buffer.put(b, off + written, len - written);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedOutputStream"));
+ }
+
+ // Optimisation. If there is no data to flush then do not send an
+ // empty message.
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
+ doWrite(false);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (closeLock) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ }
+
+ doWrite(true);
+ }
+
+ private void doWrite(boolean last) throws IOException {
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
+ buffer.flip();
+ endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last);
+ }
+ endpoint.stateMachine.complete(last);
+ buffer.clear();
+ }
+ }
+
+
+ private static class WsWriter extends Writer {
+
+ private final WsRemoteEndpointImplBase endpoint;
+ private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE);
+ private final Object closeLock = new Object();
+ private volatile boolean closed = false;
+ private volatile boolean used = false;
+
+ public WsWriter(WsRemoteEndpointImplBase endpoint) {
+ this.endpoint = endpoint;
+ }
+
+ @Override
+ public void write(char[] cbuf, int off, int len) throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedWriter"));
+ }
+ if (len == 0) {
+ return;
+ }
+ if ((off < 0) || (off > cbuf.length) || (len < 0) ||
+ ((off + len) > cbuf.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ }
+
+ used = true;
+ if (buffer.remaining() == 0) {
+ flush();
+ }
+ int remaining = buffer.remaining();
+ int written = 0;
+
+ while (remaining < len - written) {
+ buffer.put(cbuf, off + written, remaining);
+ written += remaining;
+ flush();
+ remaining = buffer.remaining();
+ }
+ buffer.put(cbuf, off + written, len - written);
+ }
+
+ @Override
+ public void flush() throws IOException {
+ if (closed) {
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.closedWriter"));
+ }
+
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) {
+ doWrite(false);
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ synchronized (closeLock) {
+ if (closed) {
+ return;
+ }
+ closed = true;
+ }
+
+ doWrite(true);
+ }
+
+ private void doWrite(boolean last) throws IOException {
+ if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) {
+ buffer.flip();
+ endpoint.sendMessageBlock(buffer, last);
+ buffer.clear();
+ } else {
+ endpoint.stateMachine.complete(last);
+ }
+ }
+ }
+
+
+ private static class EncoderEntry {
+
+ private final Class<?> clazz;
+ private final Encoder encoder;
+
+ public EncoderEntry(Class<?> clazz, Encoder encoder) {
+ this.clazz = clazz;
+ this.encoder = encoder;
+ }
+
+ public Class<?> getClazz() {
+ return clazz;
+ }
+
+ public Encoder getEncoder() {
+ return encoder;
+ }
+ }
+
+
+ private enum State {
+ OPEN,
+ STREAM_WRITING,
+ WRITER_WRITING,
+ BINARY_PARTIAL_WRITING,
+ BINARY_PARTIAL_READY,
+ BINARY_FULL_WRITING,
+ TEXT_PARTIAL_WRITING,
+ TEXT_PARTIAL_READY,
+ TEXT_FULL_WRITING
+ }
+
+
+ private static class StateMachine {
+ private State state = State.OPEN;
+
+ public synchronized void streamStart() {
+ checkState(State.OPEN);
+ state = State.STREAM_WRITING;
+ }
+
+ public synchronized void writeStart() {
+ checkState(State.OPEN);
+ state = State.WRITER_WRITING;
+ }
+
+ public synchronized void binaryPartialStart() {
+ checkState(State.OPEN, State.BINARY_PARTIAL_READY);
+ state = State.BINARY_PARTIAL_WRITING;
+ }
+
+ public synchronized void binaryStart() {
+ checkState(State.OPEN);
+ state = State.BINARY_FULL_WRITING;
+ }
+
+ public synchronized void textPartialStart() {
+ checkState(State.OPEN, State.TEXT_PARTIAL_READY);
+ state = State.TEXT_PARTIAL_WRITING;
+ }
+
+ public synchronized void textStart() {
+ checkState(State.OPEN);
+ state = State.TEXT_FULL_WRITING;
+ }
+
+ public synchronized void complete(boolean last) {
+ if (last) {
+ checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING,
+ State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING,
+ State.STREAM_WRITING, State.WRITER_WRITING);
+ state = State.OPEN;
+ } else {
+ checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING,
+ State.STREAM_WRITING, State.WRITER_WRITING);
+ if (state == State.TEXT_PARTIAL_WRITING) {
+ state = State.TEXT_PARTIAL_READY;
+ } else if (state == State.BINARY_PARTIAL_WRITING){
+ state = State.BINARY_PARTIAL_READY;
+ } else if (state == State.WRITER_WRITING) {
+ // NO-OP. Leave state as is.
+ } else if (state == State.STREAM_WRITING) {
+ // NO-OP. Leave state as is.
+ } else {
+ // Should never happen
+ // The if ... else ... blocks above should cover all states
+ // permitted by the preceding checkState() call
+ throw new IllegalStateException(
+ "BUG: This code should never be called");
+ }
+ }
+ }
+
+ private void checkState(State... required) {
+ for (State state : required) {
+ if (this.state == state) {
+ return;
+ }
+ }
+ throw new IllegalStateException(
+ sm.getString("wsRemoteEndpoint.wrongState", this.state));
+ }
+ }
+
+
+ private static class StateUpdateSendHandler implements SendHandler {
+
+ private final SendHandler handler;
+ private final StateMachine stateMachine;
+
+ public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) {
+ this.handler = handler;
+ this.stateMachine = stateMachine;
+ }
+
+ @Override
+ public void onResult(SendResult result) {
+ if (result.isOK()) {
+ stateMachine.complete(true);
+ }
+ handler.onResult(result);
+ }
+ }
+
+
+ private static class BlockingSendHandler implements SendHandler {
+
+ private SendResult sendResult = null;
+
+ @Override
+ public void onResult(SendResult result) {
+ sendResult = result;
+ }
+
+ public SendResult getSendResult() {
+ return sendResult;
+ }
+ }
+}