diff options
Diffstat (limited to 'src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java')
-rw-r--r-- | src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java | 1234 |
1 files changed, 1234 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java new file mode 100644 index 00000000..776124fd --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.EncodeException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.buf.Utf8Encoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplBase.class); + + protected static final SendResult SENDRESULT_OK = new SendResult(); + + private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static + + private final StateMachine stateMachine = new StateMachine(); + + private final IntermediateMessageHandler intermediateMessageHandler = + new IntermediateMessageHandler(this); + + private Transformation transformation = null; + private final Semaphore messagePartInProgress = new Semaphore(1); + private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>(); + private final Object messagePartLock = new Object(); + + // State + private volatile boolean closed = false; + private boolean fragmented = false; + private boolean nextFragmented = false; + private boolean text = false; + private boolean nextText = false; + + // Max size of WebSocket header is 14 bytes + private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); + private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final CharsetEncoder encoder = new Utf8Encoder(); + private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final AtomicBoolean batchingAllowed = new AtomicBoolean(false); + private volatile long sendTimeout = -1; + private WsSession wsSession; + private List<EncoderEntry> encoderEntries = new ArrayList<>(); + + private Request request; + + + protected void setTransformation(Transformation transformation) { + this.transformation = transformation; + } + + + public long getSendTimeout() { + return sendTimeout; + } + + + public void setSendTimeout(long timeout) { + this.sendTimeout = timeout; + } + + + @Override + public void setBatchingAllowed(boolean batchingAllowed) throws IOException { + boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); + + if (oldValue && !batchingAllowed) { + flushBatch(); + } + } + + + @Override + public boolean getBatchingAllowed() { + return batchingAllowed.get(); + } + + + @Override + public void flushBatch() throws IOException { + sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true); + } + + + public void sendBytes(ByteBuffer data) throws IOException { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryStart(); + sendMessageBlock(Constants.OPCODE_BINARY, data, true); + stateMachine.complete(true); + } + + + public Future<Void> sendBytesByFuture(ByteBuffer data) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendBytesByCompletion(data, f2sh); + return f2sh; + } + + + public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine); + stateMachine.binaryStart(); + startMessage(Constants.OPCODE_BINARY, data, true, sush); + } + + + public void sendPartialBytes(ByteBuffer partialByte, boolean last) + throws IOException { + if (partialByte == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryPartialStart(); + sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last); + stateMachine.complete(last); + } + + + @Override + public void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PING, applicationData, true); + } + + + @Override + public void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PONG, applicationData, true); + } + + + public void sendString(String text) throws IOException { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textStart(); + sendMessageBlock(CharBuffer.wrap(text), true); + } + + + public Future<Void> sendStringByFuture(String text) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendStringByCompletion(text, f2sh); + return f2sh; + } + + + public void sendStringByCompletion(String text, SendHandler handler) { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + stateMachine.textStart(); + TextMessageSendHandler tmsh = new TextMessageSendHandler(handler, + CharBuffer.wrap(text), true, encoder, encoderBuffer, this); + tmsh.write(); + // TextMessageSendHandler will update stateMachine when it completes + } + + + public void sendPartialString(String fragment, boolean isLast) + throws IOException { + if (fragment == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textPartialStart(); + sendMessageBlock(CharBuffer.wrap(fragment), isLast); + } + + + public OutputStream getSendStream() { + stateMachine.streamStart(); + return new WsOutputStream(this); + } + + + public Writer getSendWriter() { + stateMachine.writeStart(); + return new WsWriter(this); + } + + + void sendMessageBlock(CharBuffer part, boolean last) throws IOException { + long timeoutExpiry = getTimeoutExpiry(); + boolean isDone = false; + while (!isDone) { + encoderBuffer.clear(); + CoderResult cr = encoder.encode(part, encoderBuffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + encoderBuffer.flip(); + sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry); + } + stateMachine.complete(last); + } + + + void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last) + throws IOException { + sendMessageBlock(opCode, payload, last, getTimeoutExpiry()); + } + + + private long getTimeoutExpiry() { + // Get the timeout before we send the message. The message may + // trigger a session close and depending on timing the client + // session may close before we can read the timeout. + long timeout = getBlockingSendTimeout(); + if (timeout < 0) { + return Long.MAX_VALUE; + } else { + return System.currentTimeMillis() + timeout; + } + } + + private byte currentOpCode = Constants.OPCODE_CONTINUATION; + + private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, + long timeoutExpiry) throws IOException { + wsSession.updateLastActive(); + + if (opCode == currentOpCode) { + opCode = Constants.OPCODE_CONTINUATION; + } + + request.sendWsFrame(payload, opCode, last, timeoutExpiry); + + if (!last && opCode != Constants.OPCODE_CONTINUATION) { + currentOpCode = opCode; + } + + if (last && opCode == Constants.OPCODE_CONTINUATION) { + currentOpCode = Constants.OPCODE_CONTINUATION; + } + } + + + void startMessage(byte opCode, ByteBuffer payload, boolean last, + SendHandler handler) { + + wsSession.updateLastActive(); + + List<MessagePart> messageParts = new ArrayList<>(); + messageParts.add(new MessagePart(last, 0, opCode, payload, + intermediateMessageHandler, + new EndMessageHandler(this, handler), -1)); + + messageParts = transformation.sendMessagePart(messageParts); + + // Some extensions/transformations may buffer messages so it is possible + // that no message parts will be returned. If this is the case the + // trigger the supplied SendHandler + if (messageParts.size() == 0) { + handler.onResult(new SendResult()); + return; + } + + MessagePart mp = messageParts.remove(0); + + boolean doWrite = false; + synchronized (messagePartLock) { + if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) { + // Should not happen. To late to send batched messages now since + // the session has been closed. Complain loudly. + log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed")); + } + if (messagePartInProgress.tryAcquire()) { + doWrite = true; + } else { + // When a control message is sent while another message is being + // sent, the control message is queued. Chances are the + // subsequent data message part will end up queued while the + // control message is sent. The logic in this class (state + // machine, EndMessageHandler, TextMessageSendHandler) ensures + // that there will only ever be one data message part in the + // queue. There could be multiple control messages in the queue. + + // Add it to the queue + messagePartQueue.add(mp); + } + // Add any remaining messages to the queue + messagePartQueue.addAll(messageParts); + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mp); + } + } + + + void endMessage(SendHandler handler, SendResult result) { + boolean doWrite = false; + MessagePart mpNext = null; + synchronized (messagePartLock) { + + fragmented = nextFragmented; + text = nextText; + + mpNext = messagePartQueue.poll(); + if (mpNext == null) { + messagePartInProgress.release(); + } else if (!closed){ + // Session may have been closed unexpectedly in the middle of + // sending a fragmented message closing the endpoint. If this + // happens, clearly there is no point trying to send the rest of + // the message. + doWrite = true; + } + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mpNext); + } + + wsSession.updateLastActive(); + + // Some handlers, such as the IntermediateMessageHandler, do not have a + // nested handler so handler may be null. + if (handler != null) { + handler.onResult(result); + } + } + + + void writeMessagePart(MessagePart mp) { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closed")); + } + + if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) { + nextFragmented = fragmented; + nextText = text; + outputBuffer.flip(); + SendHandler flushHandler = new OutputBufferFlushSendHandler( + outputBuffer, mp.getEndHandler()); + doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer); + return; + } + + // Control messages may be sent in the middle of fragmented message + // so they have no effect on the fragmented or text flags + boolean first; + if (Util.isControl(mp.getOpCode())) { + nextFragmented = fragmented; + nextText = text; + if (mp.getOpCode() == Constants.OPCODE_CLOSE) { + closed = true; + } + first = true; + } else { + boolean isText = Util.isText(mp.getOpCode()); + + if (fragmented) { + // Currently fragmented + if (text != isText) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.changeType")); + } + nextText = text; + nextFragmented = !mp.isFin(); + first = false; + } else { + // Wasn't fragmented. Might be now + if (mp.isFin()) { + nextFragmented = false; + } else { + nextFragmented = true; + nextText = isText; + } + first = true; + } + } + + byte[] mask; + + if (isMasked()) { + mask = Util.generateMask(); + } else { + mask = null; + } + + headerBuffer.clear(); + writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(), + isMasked(), mp.getPayload(), mask, first); + headerBuffer.flip(); + + if (getBatchingAllowed() || isMasked()) { + // Need to write via output buffer + OutputBufferSendHandler obsh = new OutputBufferSendHandler( + mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload(), mask, + outputBuffer, !getBatchingAllowed(), this); + obsh.write(); + } else { + // Can write directly + doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload()); + } + } + + + private long getBlockingSendTimeout() { + Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY); + Long userTimeout = null; + if (obj instanceof Long) { + userTimeout = (Long) obj; + } + if (userTimeout == null) { + return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT; + } else { + return userTimeout.longValue(); + } + } + + + /** + * Wraps the user provided handler so that the end point is notified when + * the message is complete. + */ + private static class EndMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + private final SendHandler handler; + + public EndMessageHandler(WsRemoteEndpointImplBase endpoint, + SendHandler handler) { + this.endpoint = endpoint; + this.handler = handler; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(handler, result); + } + } + + + /** + * If a transformation needs to split a {@link MessagePart} into multiple + * {@link MessagePart}s, it uses this handler as the end handler for each of + * the additional {@link MessagePart}s. This handler notifies this this + * class that the {@link MessagePart} has been processed and that the next + * {@link MessagePart} in the queue should be started. The final + * {@link MessagePart} will use the {@link EndMessageHandler} provided with + * the original {@link MessagePart}. + */ + private static class IntermediateMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + + public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(null, result); + } + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObject(Object obj) throws IOException, EncodeException { + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendString(msg); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytes(msg); + return; + } + + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendString(msg); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytes(msg); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } + + + public Future<Void> sendObjectByFuture(Object obj) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendObjectByCompletion(obj, f2sh); + return f2sh; + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObjectByCompletion(Object obj, SendHandler completion) { + + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (completion == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendStringByCompletion(msg, completion); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytesByCompletion(msg, completion); + return; + } + + try { + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendStringByCompletion(msg, completion); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + completion.onResult(new SendResult()); + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytesByCompletion(msg, completion); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + completion.onResult(new SendResult()); + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } catch (Exception e) { + SendResult sr = new SendResult(e); + completion.onResult(sr); + } + } + + + protected void setSession(WsSession wsSession) { + this.wsSession = wsSession; + } + + + protected void setRequest(Request request) { + this.request = request; + } + + protected void setEncoders(EndpointConfig endpointConfig) + throws DeploymentException { + encoderEntries.clear(); + for (Class<? extends Encoder> encoderClazz : + endpointConfig.getEncoders()) { + Encoder instance; + try { + instance = encoderClazz.getConstructor().newInstance(); + instance.init(endpointConfig); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("wsRemoteEndpoint.invalidEncoder", + encoderClazz.getName()), e); + } + EncoderEntry entry = new EncoderEntry( + Util.getEncoderType(encoderClazz), instance); + encoderEntries.add(entry); + } + } + + + private Encoder findEncoder(Object obj) { + for (EncoderEntry entry : encoderEntries) { + if (entry.getClazz().isAssignableFrom(obj.getClass())) { + return entry.getEncoder(); + } + } + return null; + } + + + public final void close() { + for (EncoderEntry entry : encoderEntries) { + entry.getEncoder().destroy(); + } + + request.closeWs(); + } + + + protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data); + protected abstract boolean isMasked(); + protected abstract void doClose(); + + private static void writeHeader(ByteBuffer headerBuffer, boolean fin, + int rsv, byte opCode, boolean masked, ByteBuffer payload, + byte[] mask, boolean first) { + + byte b = 0; + + if (fin) { + // Set the fin bit + b -= 128; + } + + b += (rsv << 4); + + if (first) { + // This is the first fragment of this message + b += opCode; + } + // If not the first fragment, it is a continuation with opCode of zero + + headerBuffer.put(b); + + if (masked) { + b = (byte) 0x80; + } else { + b = 0; + } + + // Next write the mask && length length + if (payload.limit() < 126) { + headerBuffer.put((byte) (payload.limit() | b)); + } else if (payload.limit() < 65536) { + headerBuffer.put((byte) (126 | b)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } else { + // Will never be more than 2^31-1 + headerBuffer.put((byte) (127 | b)); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) (payload.limit() >>> 24)); + headerBuffer.put((byte) (payload.limit() >>> 16)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } + if (masked) { + headerBuffer.put(mask[0]); + headerBuffer.put(mask[1]); + headerBuffer.put(mask[2]); + headerBuffer.put(mask[3]); + } + } + + + private class TextMessageSendHandler implements SendHandler { + + private final SendHandler handler; + private final CharBuffer message; + private final boolean isLast; + private final CharsetEncoder encoder; + private final ByteBuffer buffer; + private final WsRemoteEndpointImplBase endpoint; + private volatile boolean isDone = false; + + public TextMessageSendHandler(SendHandler handler, CharBuffer message, + boolean isLast, CharsetEncoder encoder, + ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) { + this.handler = handler; + this.message = message; + this.isLast = isLast; + this.encoder = encoder.reset(); + this.buffer = encoderBuffer; + this.endpoint = endpoint; + } + + public void write() { + buffer.clear(); + CoderResult cr = encoder.encode(message, buffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + buffer.flip(); + endpoint.startMessage(Constants.OPCODE_TEXT, buffer, + isDone && isLast, this); + } + + @Override + public void onResult(SendResult result) { + if (isDone) { + endpoint.stateMachine.complete(isLast); + handler.onResult(result); + } else if(!result.isOK()) { + handler.onResult(result); + } else if (closed){ + SendResult sr = new SendResult(new IOException( + sm.getString("wsRemoteEndpoint.closedDuringMessage"))); + handler.onResult(sr); + } else { + write(); + } + } + } + + + /** + * Used to write data to the output buffer, flushing the buffer if it fills + * up. + */ + private static class OutputBufferSendHandler implements SendHandler { + + private final SendHandler handler; + private final long blockingWriteTimeoutExpiry; + private final ByteBuffer headerBuffer; + private final ByteBuffer payload; + private final byte[] mask; + private final ByteBuffer outputBuffer; + private final boolean flushRequired; + private final WsRemoteEndpointImplBase endpoint; + private int maskIndex = 0; + + public OutputBufferSendHandler(SendHandler completion, + long blockingWriteTimeoutExpiry, + ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, + ByteBuffer outputBuffer, boolean flushRequired, + WsRemoteEndpointImplBase endpoint) { + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + this.handler = completion; + this.headerBuffer = headerBuffer; + this.payload = payload; + this.mask = mask; + this.outputBuffer = outputBuffer; + this.flushRequired = flushRequired; + this.endpoint = endpoint; + } + + public void write() { + // Write the header + while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put(headerBuffer.get()); + } + if (headerBuffer.hasRemaining()) { + // Still more headers to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + // Write the payload + int payloadLeft = payload.remaining(); + int payloadLimit = payload.limit(); + int outputSpace = outputBuffer.remaining(); + int toWrite = payloadLeft; + + if (payloadLeft > outputSpace) { + toWrite = outputSpace; + // Temporarily reduce the limit + payload.limit(payload.position() + toWrite); + } + + if (mask == null) { + // Use a bulk copy + outputBuffer.put(payload); + } else { + for (int i = 0; i < toWrite; i++) { + outputBuffer.put( + (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); + if (maskIndex > 3) { + maskIndex = 0; + } + } + } + + if (payloadLeft > outputSpace) { + // Restore the original limit + payload.limit(payloadLimit); + // Still more data to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + if (flushRequired) { + outputBuffer.flip(); + if (outputBuffer.remaining() == 0) { + handler.onResult(SENDRESULT_OK); + } else { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } + } else { + handler.onResult(SENDRESULT_OK); + } + } + + // ------------------------------------------------- SendHandler methods + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + if (outputBuffer.hasRemaining()) { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } else { + outputBuffer.clear(); + write(); + } + } else { + handler.onResult(result); + } + } + } + + + /** + * Ensures that the output buffer is cleared after it has been flushed. + */ + private static class OutputBufferFlushSendHandler implements SendHandler { + + private final ByteBuffer outputBuffer; + private final SendHandler handler; + + public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) { + this.outputBuffer = outputBuffer; + this.handler = handler; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + outputBuffer.clear(); + } + handler.onResult(result); + } + } + + + private static class WsOutputStream extends OutputStream { + + private final WsRemoteEndpointImplBase endpoint; + private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsOutputStream(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(int b) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + buffer.put((byte) b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > b.length) || (len < 0) || + ((off + len) > b.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(b, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(b, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + // Optimisation. If there is no data to flush then do not send an + // empty message. + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last); + } + endpoint.stateMachine.complete(last); + buffer.clear(); + } + } + + + private static class WsWriter extends Writer { + + private final WsRemoteEndpointImplBase endpoint; + private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsWriter(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(char[] cbuf, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(cbuf, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(cbuf, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(buffer, last); + buffer.clear(); + } else { + endpoint.stateMachine.complete(last); + } + } + } + + + private static class EncoderEntry { + + private final Class<?> clazz; + private final Encoder encoder; + + public EncoderEntry(Class<?> clazz, Encoder encoder) { + this.clazz = clazz; + this.encoder = encoder; + } + + public Class<?> getClazz() { + return clazz; + } + + public Encoder getEncoder() { + return encoder; + } + } + + + private enum State { + OPEN, + STREAM_WRITING, + WRITER_WRITING, + BINARY_PARTIAL_WRITING, + BINARY_PARTIAL_READY, + BINARY_FULL_WRITING, + TEXT_PARTIAL_WRITING, + TEXT_PARTIAL_READY, + TEXT_FULL_WRITING + } + + + private static class StateMachine { + private State state = State.OPEN; + + public synchronized void streamStart() { + checkState(State.OPEN); + state = State.STREAM_WRITING; + } + + public synchronized void writeStart() { + checkState(State.OPEN); + state = State.WRITER_WRITING; + } + + public synchronized void binaryPartialStart() { + checkState(State.OPEN, State.BINARY_PARTIAL_READY); + state = State.BINARY_PARTIAL_WRITING; + } + + public synchronized void binaryStart() { + checkState(State.OPEN); + state = State.BINARY_FULL_WRITING; + } + + public synchronized void textPartialStart() { + checkState(State.OPEN, State.TEXT_PARTIAL_READY); + state = State.TEXT_PARTIAL_WRITING; + } + + public synchronized void textStart() { + checkState(State.OPEN); + state = State.TEXT_FULL_WRITING; + } + + public synchronized void complete(boolean last) { + if (last) { + checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING, + State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + state = State.OPEN; + } else { + checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + if (state == State.TEXT_PARTIAL_WRITING) { + state = State.TEXT_PARTIAL_READY; + } else if (state == State.BINARY_PARTIAL_WRITING){ + state = State.BINARY_PARTIAL_READY; + } else if (state == State.WRITER_WRITING) { + // NO-OP. Leave state as is. + } else if (state == State.STREAM_WRITING) { + // NO-OP. Leave state as is. + } else { + // Should never happen + // The if ... else ... blocks above should cover all states + // permitted by the preceding checkState() call + throw new IllegalStateException( + "BUG: This code should never be called"); + } + } + } + + private void checkState(State... required) { + for (State state : required) { + if (this.state == state) { + return; + } + } + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.wrongState", this.state)); + } + } + + + private static class StateUpdateSendHandler implements SendHandler { + + private final SendHandler handler; + private final StateMachine stateMachine; + + public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) { + this.handler = handler; + this.stateMachine = stateMachine; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + stateMachine.complete(true); + } + handler.onResult(result); + } + } + + + private static class BlockingSendHandler implements SendHandler { + + private SendResult sendResult = null; + + @Override + public void onResult(SendResult result) { + sendResult = result; + } + + public SendResult getSendResult() { + return sendResult; + } + } +} |