diff options
Diffstat (limited to '')
146 files changed, 17643 insertions, 427 deletions
diff --git a/src/go/unit/nxt_cgo_lib.c b/src/go/unit/nxt_cgo_lib.c index cc1228f5..5cb31b5a 100644 --- a/src/go/unit/nxt_cgo_lib.c +++ b/src/go/unit/nxt_cgo_lib.c @@ -6,7 +6,6 @@ #include "_cgo_export.h" -#include <nxt_main.h> #include <nxt_unit.h> #include <nxt_unit_request.h> @@ -39,7 +38,7 @@ nxt_cgo_run(uintptr_t handler) init.data = (void *) handler; ctx = nxt_unit_init(&init); - if (nxt_slow_path(ctx == NULL)) { + if (ctx == NULL) { return NXT_UNIT_ERROR; } @@ -171,7 +170,7 @@ nxt_cgo_response_write(uintptr_t req, uintptr_t start, uint32_t len) rc = nxt_unit_response_write((nxt_unit_request_info_t *) req, (void *) start, len); - if (nxt_slow_path(rc != NXT_UNIT_OK)) { + if (rc != NXT_UNIT_OK) { return -1; } diff --git a/src/go/unit/request.go b/src/go/unit/request.go index ad56cabb..1d8c6702 100644 --- a/src/go/unit/request.go +++ b/src/go/unit/request.go @@ -135,7 +135,7 @@ func nxt_go_request_set_tls(go_req uintptr) { //export nxt_go_request_handler func nxt_go_request_handler(go_req uintptr, h uintptr) { r := get_request(go_req) - handler := *(*http.Handler)(unsafe.Pointer(h)) + handler := get_handler(h) go func(r *request) { handler.ServeHTTP(r.response(), &r.req) diff --git a/src/go/unit/unit.go b/src/go/unit/unit.go index 06257768..1534479e 100644 --- a/src/go/unit/unit.go +++ b/src/go/unit/unit.go @@ -13,6 +13,7 @@ import "C" import ( "fmt" "net/http" + "sync" "unsafe" ) @@ -87,12 +88,58 @@ func nxt_go_warn(format string, args ...interface{}) { C.nxt_cgo_warn(str_ref(str), C.uint32_t(len(str))) } +type handler_registry struct { + sync.RWMutex + next uintptr + m map[uintptr]*http.Handler +} + +var handler_registry_ handler_registry + +func set_handler(handler *http.Handler) uintptr { + + handler_registry_.Lock() + if handler_registry_.m == nil { + handler_registry_.m = make(map[uintptr]*http.Handler) + handler_registry_.next = 1 + } + + h := handler_registry_.next + handler_registry_.next += 1 + handler_registry_.m[h] = handler + + handler_registry_.Unlock() + + return h +} + +func get_handler(h uintptr) http.Handler { + handler_registry_.RLock() + defer handler_registry_.RUnlock() + + return *handler_registry_.m[h] +} + +func reset_handler(h uintptr) { + + handler_registry_.Lock() + if handler_registry_.m != nil { + delete(handler_registry_.m, h) + } + + handler_registry_.Unlock() +} + func ListenAndServe(addr string, handler http.Handler) error { if handler == nil { handler = http.DefaultServeMux } - rc := C.nxt_cgo_run(C.uintptr_t(uintptr(unsafe.Pointer(&handler)))) + h := set_handler(&handler) + + rc := C.nxt_cgo_run(C.uintptr_t(h)) + + reset_handler(h) if rc != 0 { return http.ListenAndServe(addr, handler) diff --git a/src/java/javax/websocket/ClientEndpoint.java b/src/java/javax/websocket/ClientEndpoint.java new file mode 100644 index 00000000..ee984171 --- /dev/null +++ b/src/java/javax/websocket/ClientEndpoint.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.websocket.ClientEndpointConfig.Configurator; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ClientEndpoint { + String[] subprotocols() default {}; + Class<? extends Decoder>[] decoders() default {}; + Class<? extends Encoder>[] encoders() default {}; + public Class<? extends Configurator> configurator() + default Configurator.class; +} diff --git a/src/java/javax/websocket/ClientEndpointConfig.java b/src/java/javax/websocket/ClientEndpointConfig.java new file mode 100644 index 00000000..13b6cba5 --- /dev/null +++ b/src/java/javax/websocket/ClientEndpointConfig.java @@ -0,0 +1,138 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +public interface ClientEndpointConfig extends EndpointConfig { + + List<String> getPreferredSubprotocols(); + + List<Extension> getExtensions(); + + public Configurator getConfigurator(); + + public final class Builder { + + private static final Configurator DEFAULT_CONFIGURATOR = + new Configurator() {}; + + + public static Builder create() { + return new Builder(); + } + + + private Builder() { + // Hide default constructor + } + + private Configurator configurator = DEFAULT_CONFIGURATOR; + private List<String> preferredSubprotocols = Collections.emptyList(); + private List<Extension> extensions = Collections.emptyList(); + private List<Class<? extends Encoder>> encoders = + Collections.emptyList(); + private List<Class<? extends Decoder>> decoders = + Collections.emptyList(); + + + public ClientEndpointConfig build() { + return new DefaultClientEndpointConfig(preferredSubprotocols, + extensions, encoders, decoders, configurator); + } + + + public Builder configurator(Configurator configurator) { + if (configurator == null) { + this.configurator = DEFAULT_CONFIGURATOR; + } else { + this.configurator = configurator; + } + return this; + } + + + public Builder preferredSubprotocols( + List<String> preferredSubprotocols) { + if (preferredSubprotocols == null || + preferredSubprotocols.size() == 0) { + this.preferredSubprotocols = Collections.emptyList(); + } else { + this.preferredSubprotocols = + Collections.unmodifiableList(preferredSubprotocols); + } + return this; + } + + + public Builder extensions( + List<Extension> extensions) { + if (extensions == null || extensions.size() == 0) { + this.extensions = Collections.emptyList(); + } else { + this.extensions = Collections.unmodifiableList(extensions); + } + return this; + } + + + public Builder encoders(List<Class<? extends Encoder>> encoders) { + if (encoders == null || encoders.size() == 0) { + this.encoders = Collections.emptyList(); + } else { + this.encoders = Collections.unmodifiableList(encoders); + } + return this; + } + + + public Builder decoders(List<Class<? extends Decoder>> decoders) { + if (decoders == null || decoders.size() == 0) { + this.decoders = Collections.emptyList(); + } else { + this.decoders = Collections.unmodifiableList(decoders); + } + return this; + } + } + + + public class Configurator { + + /** + * Provides the client with a mechanism to inspect and/or modify the headers + * that are sent to the server to start the WebSocket handshake. + * + * @param headers The HTTP headers + */ + public void beforeRequest(Map<String, List<String>> headers) { + // NO-OP + } + + /** + * Provides the client with a mechanism to inspect the handshake response + * that is returned from the server. + * + * @param handshakeResponse The response + */ + public void afterResponse(HandshakeResponse handshakeResponse) { + // NO-OP + } + } +} diff --git a/src/java/javax/websocket/CloseReason.java b/src/java/javax/websocket/CloseReason.java new file mode 100644 index 00000000..ef88d135 --- /dev/null +++ b/src/java/javax/websocket/CloseReason.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class CloseReason { + + private final CloseCode closeCode; + private final String reasonPhrase; + + public CloseReason(CloseReason.CloseCode closeCode, String reasonPhrase) { + this.closeCode = closeCode; + this.reasonPhrase = reasonPhrase; + } + + public CloseCode getCloseCode() { + return closeCode; + } + + public String getReasonPhrase() { + return reasonPhrase; + } + + @Override + public String toString() { + return "CloseReason: code [" + closeCode.getCode() + + "], reason [" + reasonPhrase + "]"; + } + + public interface CloseCode { + int getCode(); + } + + public enum CloseCodes implements CloseReason.CloseCode { + + NORMAL_CLOSURE(1000), + GOING_AWAY(1001), + PROTOCOL_ERROR(1002), + CANNOT_ACCEPT(1003), + RESERVED(1004), + NO_STATUS_CODE(1005), + CLOSED_ABNORMALLY(1006), + NOT_CONSISTENT(1007), + VIOLATED_POLICY(1008), + TOO_BIG(1009), + NO_EXTENSION(1010), + UNEXPECTED_CONDITION(1011), + SERVICE_RESTART(1012), + TRY_AGAIN_LATER(1013), + TLS_HANDSHAKE_FAILURE(1015); + + private int code; + + CloseCodes(int code) { + this.code = code; + } + + public static CloseCode getCloseCode(final int code) { + if (code > 2999 && code < 5000) { + return new CloseCode() { + @Override + public int getCode() { + return code; + } + }; + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + return CloseCodes.RESERVED; + case 1005: + return CloseCodes.NO_STATUS_CODE; + case 1006: + return CloseCodes.CLOSED_ABNORMALLY; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + return CloseCodes.SERVICE_RESTART; + case 1013: + return CloseCodes.TRY_AGAIN_LATER; + case 1015: + return CloseCodes.TLS_HANDSHAKE_FAILURE; + default: + throw new IllegalArgumentException( + "Invalid close code: [" + code + "]"); + } + } + + @Override + public int getCode() { + return code; + } + } +} diff --git a/src/java/javax/websocket/ContainerProvider.java b/src/java/javax/websocket/ContainerProvider.java new file mode 100644 index 00000000..1727ca93 --- /dev/null +++ b/src/java/javax/websocket/ContainerProvider.java @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.Iterator; +import java.util.ServiceLoader; + +/** + * Use the {@link ServiceLoader} mechanism to provide instances of the WebSocket + * client container. + */ +public abstract class ContainerProvider { + + private static final String DEFAULT_PROVIDER_CLASS_NAME = + "nginx.unit.websocket.WsWebSocketContainer"; + + /** + * Create a new container used to create outgoing WebSocket connections. + * + * @return A newly created container. + */ + public static WebSocketContainer getWebSocketContainer() { + WebSocketContainer result = null; + + ServiceLoader<ContainerProvider> serviceLoader = + ServiceLoader.load(ContainerProvider.class); + Iterator<ContainerProvider> iter = serviceLoader.iterator(); + while (result == null && iter.hasNext()) { + result = iter.next().getContainer(); + } + + // Fall-back. Also used by unit tests + if (result == null) { + try { + @SuppressWarnings("unchecked") + Class<WebSocketContainer> clazz = + (Class<WebSocketContainer>) Class.forName( + DEFAULT_PROVIDER_CLASS_NAME); + result = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException | IllegalArgumentException | + SecurityException e) { + // No options left. Just return null. + } + } + return result; + } + + protected abstract WebSocketContainer getContainer(); +} diff --git a/src/java/javax/websocket/DecodeException.java b/src/java/javax/websocket/DecodeException.java new file mode 100644 index 00000000..771cfa58 --- /dev/null +++ b/src/java/javax/websocket/DecodeException.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.nio.ByteBuffer; + +public class DecodeException extends Exception { + + private static final long serialVersionUID = 1L; + + private ByteBuffer bb; + private String encodedString; + + public DecodeException(ByteBuffer bb, String message, Throwable cause) { + super(message, cause); + this.bb = bb; + } + + public DecodeException(String encodedString, String message, + Throwable cause) { + super(message, cause); + this.encodedString = encodedString; + } + + public DecodeException(ByteBuffer bb, String message) { + super(message); + this.bb = bb; + } + + public DecodeException(String encodedString, String message) { + super(message); + this.encodedString = encodedString; + } + + public ByteBuffer getBytes() { + return bb; + } + + public String getText() { + return encodedString; + } +} diff --git a/src/java/javax/websocket/Decoder.java b/src/java/javax/websocket/Decoder.java new file mode 100644 index 00000000..fad262e3 --- /dev/null +++ b/src/java/javax/websocket/Decoder.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.InputStream; +import java.io.Reader; +import java.nio.ByteBuffer; + +public interface Decoder { + + abstract void init(EndpointConfig endpointConfig); + + abstract void destroy(); + + interface Binary<T> extends Decoder { + + T decode(ByteBuffer bytes) throws DecodeException; + + boolean willDecode(ByteBuffer bytes); + } + + interface BinaryStream<T> extends Decoder { + + T decode(InputStream is) throws DecodeException, IOException; + } + + interface Text<T> extends Decoder { + + T decode(String s) throws DecodeException; + + boolean willDecode(String s); + } + + interface TextStream<T> extends Decoder { + + T decode(Reader reader) throws DecodeException, IOException; + } +} diff --git a/src/java/javax/websocket/DefaultClientEndpointConfig.java b/src/java/javax/websocket/DefaultClientEndpointConfig.java new file mode 100644 index 00000000..ce28cb26 --- /dev/null +++ b/src/java/javax/websocket/DefaultClientEndpointConfig.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +final class DefaultClientEndpointConfig implements ClientEndpointConfig { + + private final List<String> preferredSubprotocols; + private final List<Extension> extensions; + private final List<Class<? extends Encoder>> encoders; + private final List<Class<? extends Decoder>> decoders; + private final Map<String,Object> userProperties = new ConcurrentHashMap<>(); + private final Configurator configurator; + + + DefaultClientEndpointConfig(List<String> preferredSubprotocols, + List<Extension> extensions, + List<Class<? extends Encoder>> encoders, + List<Class<? extends Decoder>> decoders, + Configurator configurator) { + this.preferredSubprotocols = preferredSubprotocols; + this.extensions = extensions; + this.decoders = decoders; + this.encoders = encoders; + this.configurator = configurator; + } + + + @Override + public List<String> getPreferredSubprotocols() { + return preferredSubprotocols; + } + + + @Override + public List<Extension> getExtensions() { + return extensions; + } + + + @Override + public List<Class<? extends Encoder>> getEncoders() { + return encoders; + } + + + @Override + public List<Class<? extends Decoder>> getDecoders() { + return decoders; + } + + + @Override + public final Map<String, Object> getUserProperties() { + return userProperties; + } + + + @Override + public Configurator getConfigurator() { + return configurator; + } +} diff --git a/src/java/javax/websocket/DeploymentException.java b/src/java/javax/websocket/DeploymentException.java new file mode 100644 index 00000000..1678fd09 --- /dev/null +++ b/src/java/javax/websocket/DeploymentException.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class DeploymentException extends Exception { + + private static final long serialVersionUID = 1L; + + public DeploymentException(String message) { + super(message); + } + + public DeploymentException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/java/javax/websocket/EncodeException.java b/src/java/javax/websocket/EncodeException.java new file mode 100644 index 00000000..fdb536ac --- /dev/null +++ b/src/java/javax/websocket/EncodeException.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class EncodeException extends Exception { + + private static final long serialVersionUID = 1L; + + private Object object; + + public EncodeException(Object object, String message) { + super(message); + this.object = object; + } + + public EncodeException(Object object, String message, Throwable cause) { + super(message, cause); + this.object = object; + } + + public Object getObject() { + return this.object; + } +} diff --git a/src/java/javax/websocket/Encoder.java b/src/java/javax/websocket/Encoder.java new file mode 100644 index 00000000..42a107f0 --- /dev/null +++ b/src/java/javax/websocket/Encoder.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; + +public interface Encoder { + + abstract void init(EndpointConfig endpointConfig); + + abstract void destroy(); + + interface Text<T> extends Encoder { + + String encode(T object) throws EncodeException; + } + + interface TextStream<T> extends Encoder { + + void encode(T object, Writer writer) + throws EncodeException, IOException; + } + + interface Binary<T> extends Encoder { + + ByteBuffer encode(T object) throws EncodeException; + } + + interface BinaryStream<T> extends Encoder { + + void encode(T object, OutputStream os) + throws EncodeException, IOException; + } +} diff --git a/src/java/javax/websocket/Endpoint.java b/src/java/javax/websocket/Endpoint.java new file mode 100644 index 00000000..9dfdbcce --- /dev/null +++ b/src/java/javax/websocket/Endpoint.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public abstract class Endpoint { + + /** + * Event that is triggered when a new session starts. + * + * @param session The new session. + * @param config The configuration with which the Endpoint was + * configured. + */ + public abstract void onOpen(Session session, EndpointConfig config); + + /** + * Event that is triggered when a session has closed. + * + * @param session The session + * @param closeReason Why the session was closed + */ + public void onClose(Session session, CloseReason closeReason) { + // NO-OP by default + } + + /** + * Event that is triggered when a protocol error occurs. + * + * @param session The session. + * @param throwable The exception. + */ + public void onError(Session session, Throwable throwable) { + // NO-OP by default + } +} diff --git a/src/java/javax/websocket/EndpointConfig.java b/src/java/javax/websocket/EndpointConfig.java new file mode 100644 index 00000000..0b6c9681 --- /dev/null +++ b/src/java/javax/websocket/EndpointConfig.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; + +public interface EndpointConfig { + + List<Class<? extends Encoder>> getEncoders(); + + List<Class<? extends Decoder>> getDecoders(); + + Map<String,Object> getUserProperties(); +} diff --git a/src/java/javax/websocket/Extension.java b/src/java/javax/websocket/Extension.java new file mode 100644 index 00000000..b95b27b8 --- /dev/null +++ b/src/java/javax/websocket/Extension.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; + +public interface Extension { + String getName(); + List<Parameter> getParameters(); + + interface Parameter { + String getName(); + String getValue(); + } +} diff --git a/src/java/javax/websocket/HandshakeResponse.java b/src/java/javax/websocket/HandshakeResponse.java new file mode 100644 index 00000000..807192e8 --- /dev/null +++ b/src/java/javax/websocket/HandshakeResponse.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.util.List; +import java.util.Map; + +public interface HandshakeResponse { + + /** + * Name of the WebSocket accept HTTP header. + */ + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + + Map<String,List<String>> getHeaders(); +} diff --git a/src/java/javax/websocket/MessageHandler.java b/src/java/javax/websocket/MessageHandler.java new file mode 100644 index 00000000..2c30d997 --- /dev/null +++ b/src/java/javax/websocket/MessageHandler.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public interface MessageHandler { + + interface Partial<T> extends MessageHandler { + + /** + * Called when part of a message is available to be processed. + * + * @param messagePart The message part + * @param last <code>true</code> if this is the last part of + * this message, else <code>false</code> + */ + void onMessage(T messagePart, boolean last); + } + + interface Whole<T> extends MessageHandler { + + /** + * Called when a whole message is available to be processed. + * + * @param message The message + */ + void onMessage(T message); + } +} diff --git a/src/java/javax/websocket/OnClose.java b/src/java/javax/websocket/OnClose.java new file mode 100644 index 00000000..6ee61d36 --- /dev/null +++ b/src/java/javax/websocket/OnClose.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnClose { +} diff --git a/src/java/javax/websocket/OnError.java b/src/java/javax/websocket/OnError.java new file mode 100644 index 00000000..ce431484 --- /dev/null +++ b/src/java/javax/websocket/OnError.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnError { +} diff --git a/src/java/javax/websocket/OnMessage.java b/src/java/javax/websocket/OnMessage.java new file mode 100644 index 00000000..564fa994 --- /dev/null +++ b/src/java/javax/websocket/OnMessage.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnMessage { + long maxMessageSize() default -1; +} diff --git a/src/java/javax/websocket/OnOpen.java b/src/java/javax/websocket/OnOpen.java new file mode 100644 index 00000000..9f0ea6e3 --- /dev/null +++ b/src/java/javax/websocket/OnOpen.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.METHOD) +public @interface OnOpen { +} diff --git a/src/java/javax/websocket/PongMessage.java b/src/java/javax/websocket/PongMessage.java new file mode 100644 index 00000000..7e9e3b6a --- /dev/null +++ b/src/java/javax/websocket/PongMessage.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.nio.ByteBuffer; + +/** + * Represents a WebSocket Pong message and used by message handlers to enable + * applications to process the response to any Pings they send. + */ +public interface PongMessage { + /** + * Get the payload of the Pong message. + * + * @return The payload of the Pong message. + */ + ByteBuffer getApplicationData(); +} diff --git a/src/java/javax/websocket/RemoteEndpoint.java b/src/java/javax/websocket/RemoteEndpoint.java new file mode 100644 index 00000000..19c7a100 --- /dev/null +++ b/src/java/javax/websocket/RemoteEndpoint.java @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; +import java.util.concurrent.Future; + + +public interface RemoteEndpoint { + + interface Async extends RemoteEndpoint { + + /** + * Obtain the timeout (in milliseconds) for sending a message + * asynchronously. The default value is determined by + * {@link WebSocketContainer#getDefaultAsyncSendTimeout()}. + * @return The current send timeout in milliseconds. A non-positive + * value means an infinite timeout. + */ + long getSendTimeout(); + + /** + * Set the timeout (in milliseconds) for sending a message + * asynchronously. The default value is determined by + * {@link WebSocketContainer#getDefaultAsyncSendTimeout()}. + * @param timeout The new timeout for sending messages asynchronously + * in milliseconds. A non-positive value means an + * infinite timeout. + */ + void setSendTimeout(long timeout); + + /** + * Send the message asynchronously, using the SendHandler to signal to the + * client when the message has been sent. + * @param text The text message to send + * @param completion Used to signal to the client when the message has + * been sent + */ + void sendText(String text, SendHandler completion); + + /** + * Send the message asynchronously, using the Future to signal to the + * client when the message has been sent. + * @param text The text message to send + * @return A Future that signals when the message has been sent. + */ + Future<Void> sendText(String text); + + /** + * Send the message asynchronously, using the Future to signal to the client + * when the message has been sent. + * @param data The text message to send + * @return A Future that signals when the message has been sent. + * @throws IllegalArgumentException if {@code data} is {@code null}. + */ + Future<Void> sendBinary(ByteBuffer data); + + /** + * Send the message asynchronously, using the SendHandler to signal to the + * client when the message has been sent. + * @param data The text message to send + * @param completion Used to signal to the client when the message has + * been sent + * @throws IllegalArgumentException if {@code data} or {@code completion} + * is {@code null}. + */ + void sendBinary(ByteBuffer data, SendHandler completion); + + /** + * Encodes object as a message and sends it asynchronously, using the + * Future to signal to the client when the message has been sent. + * @param obj The object to be sent. + * @return A Future that signals when the message has been sent. + * @throws IllegalArgumentException if {@code obj} is {@code null}. + */ + Future<Void> sendObject(Object obj); + + /** + * Encodes object as a message and sends it asynchronously, using the + * SendHandler to signal to the client when the message has been sent. + * @param obj The object to be sent. + * @param completion Used to signal to the client when the message has + * been sent + * @throws IllegalArgumentException if {@code obj} or + * {@code completion} is {@code null}. + */ + void sendObject(Object obj, SendHandler completion); + + } + + interface Basic extends RemoteEndpoint { + + /** + * Send the message, blocking until the message is sent. + * @param text The text message to send. + * @throws IllegalArgumentException if {@code text} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendText(String text) throws IOException; + + /** + * Send the message, blocking until the message is sent. + * @param data The binary message to send + * @throws IllegalArgumentException if {@code data} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendBinary(ByteBuffer data) throws IOException; + + /** + * Sends part of a text message to the remote endpoint. Once the first part + * of a message has been sent, no other text or binary messages may be sent + * until all remaining parts of this message have been sent. + * + * @param fragment The partial message to send + * @param isLast <code>true</code> if this is the last part of the + * message, otherwise <code>false</code> + * @throws IllegalArgumentException if {@code fragment} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendText(String fragment, boolean isLast) throws IOException; + + /** + * Sends part of a binary message to the remote endpoint. Once the first + * part of a message has been sent, no other text or binary messages may be + * sent until all remaining parts of this message have been sent. + * + * @param partialByte The partial message to send + * @param isLast <code>true</code> if this is the last part of the + * message, otherwise <code>false</code> + * @throws IllegalArgumentException if {@code partialByte} is + * {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendBinary(ByteBuffer partialByte, boolean isLast) throws IOException; + + OutputStream getSendStream() throws IOException; + + Writer getSendWriter() throws IOException; + + /** + * Encodes object as a message and sends it to the remote endpoint. + * @param data The object to be sent. + * @throws EncodeException if there was a problem encoding the + * {@code data} object as a websocket message. + * @throws IllegalArgumentException if {@code data} is {@code null}. + * @throws IOException if an I/O error occurs during the sending of the + * message. + */ + void sendObject(Object data) throws IOException, EncodeException; + + } + /** + * Enable or disable the batching of outgoing messages for this endpoint. If + * batching is disabled when it was previously enabled then this method will + * block until any currently batched messages have been written. + * + * @param batchingAllowed New setting + * @throws IOException If changing the value resulted in a call to + * {@link #flushBatch()} and that call threw an + * {@link IOException}. + */ + void setBatchingAllowed(boolean batchingAllowed) throws IOException; + + /** + * Obtains the current batching status of the endpoint. + * + * @return <code>true</code> if batching is enabled, otherwise + * <code>false</code>. + */ + boolean getBatchingAllowed(); + + /** + * Flush any currently batched messages to the remote endpoint. This method + * will block until the flush completes. + * + * @throws IOException If an I/O error occurs while flushing + */ + void flushBatch() throws IOException; + + /** + * Send a ping message blocking until the message has been sent. Note that + * if a message is in the process of being sent asynchronously, this method + * will block until that message and this ping has been sent. + * + * @param applicationData The payload for the ping message + * + * @throws IOException If an I/O error occurs while sending the ping + * @throws IllegalArgumentException if the applicationData is too large for + * a control message (max 125 bytes) + */ + void sendPing(ByteBuffer applicationData) + throws IOException, IllegalArgumentException; + + /** + * Send a pong message blocking until the message has been sent. Note that + * if a message is in the process of being sent asynchronously, this method + * will block until that message and this pong has been sent. + * + * @param applicationData The payload for the pong message + * + * @throws IOException If an I/O error occurs while sending the pong + * @throws IllegalArgumentException if the applicationData is too large for + * a control message (max 125 bytes) + */ + void sendPong(ByteBuffer applicationData) + throws IOException, IllegalArgumentException; +} + diff --git a/src/java/javax/websocket/SendHandler.java b/src/java/javax/websocket/SendHandler.java new file mode 100644 index 00000000..65b9a19a --- /dev/null +++ b/src/java/javax/websocket/SendHandler.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public interface SendHandler { + + void onResult(SendResult result); +} diff --git a/src/java/javax/websocket/SendResult.java b/src/java/javax/websocket/SendResult.java new file mode 100644 index 00000000..a3797d5b --- /dev/null +++ b/src/java/javax/websocket/SendResult.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public final class SendResult { + private final Throwable exception; + private final boolean ok; + + public SendResult(Throwable exception) { + this.exception = exception; + this.ok = (exception == null); + } + + public SendResult() { + this (null); + } + + public Throwable getException() { + return exception; + } + + public boolean isOK() { + return ok; + } +} diff --git a/src/java/javax/websocket/Session.java b/src/java/javax/websocket/Session.java new file mode 100644 index 00000000..eea15e5b --- /dev/null +++ b/src/java/javax/websocket/Session.java @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.Closeable; +import java.io.IOException; +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public interface Session extends Closeable { + + /** + * Get the container that created this session. + * @return the container that created this session. + */ + WebSocketContainer getContainer(); + + /** + * Registers a {@link MessageHandler} for incoming messages. Only one + * {@link MessageHandler} may be registered for each message type (text, + * binary, pong). The message type will be derived at runtime from the + * provided {@link MessageHandler} instance. It is not always possible to do + * this so it is better to use + * {@link #addMessageHandler(Class, javax.websocket.MessageHandler.Partial)} + * or + * {@link #addMessageHandler(Class, javax.websocket.MessageHandler.Whole)}. + * + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + */ + void addMessageHandler(MessageHandler handler) throws IllegalStateException; + + Set<MessageHandler> getMessageHandlers(); + + void removeMessageHandler(MessageHandler listener); + + String getProtocolVersion(); + + String getNegotiatedSubprotocol(); + + List<Extension> getNegotiatedExtensions(); + + boolean isSecure(); + + boolean isOpen(); + + /** + * Get the idle timeout for this session. + * @return The current idle timeout for this session in milliseconds. Zero + * or negative values indicate an infinite timeout. + */ + long getMaxIdleTimeout(); + + /** + * Set the idle timeout for this session. + * @param timeout The new idle timeout for this session in milliseconds. + * Zero or negative values indicate an infinite timeout. + */ + void setMaxIdleTimeout(long timeout); + + /** + * Set the current maximum buffer size for binary messages. + * @param max The new maximum buffer size in bytes + */ + void setMaxBinaryMessageBufferSize(int max); + + /** + * Get the current maximum buffer size for binary messages. + * @return The current maximum buffer size in bytes + */ + int getMaxBinaryMessageBufferSize(); + + /** + * Set the maximum buffer size for text messages. + * @param max The new maximum buffer size in characters. + */ + void setMaxTextMessageBufferSize(int max); + + /** + * Get the maximum buffer size for text messages. + * @return The maximum buffer size in characters. + */ + int getMaxTextMessageBufferSize(); + + RemoteEndpoint.Async getAsyncRemote(); + + RemoteEndpoint.Basic getBasicRemote(); + + /** + * Provides a unique identifier for the session. This identifier should not + * be relied upon to be generated from a secure random source. + * @return A unique identifier for the session. + */ + String getId(); + + /** + * Close the connection to the remote end point using the code + * {@link javax.websocket.CloseReason.CloseCodes#NORMAL_CLOSURE} and an + * empty reason phrase. + * + * @throws IOException if an I/O error occurs while the WebSocket session is + * being closed. + */ + @Override + void close() throws IOException; + + + /** + * Close the connection to the remote end point using the specified code + * and reason phrase. + * @param closeReason The reason the WebSocket session is being closed. + * + * @throws IOException if an I/O error occurs while the WebSocket session is + * being closed. + */ + void close(CloseReason closeReason) throws IOException; + + URI getRequestURI(); + + Map<String, List<String>> getRequestParameterMap(); + + String getQueryString(); + + Map<String,String> getPathParameters(); + + Map<String,Object> getUserProperties(); + + Principal getUserPrincipal(); + + /** + * Obtain the set of open sessions associated with the same local endpoint + * as this session. + * + * @return The set of currently open sessions for the local endpoint that + * this session is associated with. + */ + Set<Session> getOpenSessions(); + + /** + * Registers a {@link MessageHandler} for partial incoming messages. Only + * one {@link MessageHandler} may be registered for each message type (text + * or binary, pong messages are never presented as partial messages). + * + * @param <T> The type of message that the given handler is intended + * for + * @param clazz The Class that implements T + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + * + * @since WebSocket 1.1 + */ + <T> void addMessageHandler(Class<T> clazz, MessageHandler.Partial<T> handler) + throws IllegalStateException; + + /** + * Registers a {@link MessageHandler} for whole incoming messages. Only + * one {@link MessageHandler} may be registered for each message type (text, + * binary, pong). + * + * @param <T> The type of message that the given handler is intended + * for + * @param clazz The Class that implements T + * @param handler The message handler for a incoming message + * + * @throws IllegalStateException If a message handler has already been + * registered for the associated message type + * + * @since WebSocket 1.1 + */ + <T> void addMessageHandler(Class<T> clazz, MessageHandler.Whole<T> handler) + throws IllegalStateException; +} diff --git a/src/java/javax/websocket/SessionException.java b/src/java/javax/websocket/SessionException.java new file mode 100644 index 00000000..428b82ec --- /dev/null +++ b/src/java/javax/websocket/SessionException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +public class SessionException extends Exception { + + private static final long serialVersionUID = 1L; + + private final Session session; + + + public SessionException(String message, Throwable cause, Session session) { + super(message, cause); + this.session = session; + } + + + public Session getSession() { + return session; + } +} diff --git a/src/java/javax/websocket/WebSocketContainer.java b/src/java/javax/websocket/WebSocketContainer.java new file mode 100644 index 00000000..f2da3e43 --- /dev/null +++ b/src/java/javax/websocket/WebSocketContainer.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket; + +import java.io.IOException; +import java.net.URI; +import java.util.Set; + +public interface WebSocketContainer { + + /** + * Get the default timeout for sending a message asynchronously. + * @return The current default timeout in milliseconds. A non-positive value + * means an infinite timeout. + */ + long getDefaultAsyncSendTimeout(); + + /** + * Set the default timeout for sending a message asynchronously. + * @param timeout The new default timeout in milliseconds. A non-positive + * value means an infinite timeout. + */ + void setAsyncSendTimeout(long timeout); + + Session connectToServer(Object endpoint, URI path) + throws DeploymentException, IOException; + + Session connectToServer(Class<?> annotatedEndpointClass, URI path) + throws DeploymentException, IOException; + + /** + * Creates a new connection to the WebSocket. + * + * @param endpoint + * The endpoint instance that will handle responses from the + * server + * @param clientEndpointConfiguration + * Used to configure the new connection + * @param path + * The full URL of the WebSocket endpoint to connect to + * + * @return The WebSocket session for the connection + * + * @throws DeploymentException If the connection cannot be established + * @throws IOException If an I/O occurred while trying to establish the + * connection + */ + Session connectToServer(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException, IOException; + + /** + * Creates a new connection to the WebSocket. + * + * @param endpoint + * An instance of this class will be created to handle responses + * from the server + * @param clientEndpointConfiguration + * Used to configure the new connection + * @param path + * The full URL of the WebSocket endpoint to connect to + * + * @return The WebSocket session for the connection + * + * @throws DeploymentException If the connection cannot be established + * @throws IOException If an I/O occurred while trying to establish the + * connection + */ + Session connectToServer(Class<? extends Endpoint> endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException, IOException; + + /** + * Get the current default session idle timeout. + * @return The current default session idle timeout in milliseconds. Zero or + * negative values indicate an infinite timeout. + */ + long getDefaultMaxSessionIdleTimeout(); + + /** + * Set the default session idle timeout. + * @param timeout The new default session idle timeout in milliseconds. Zero + * or negative values indicate an infinite timeout. + */ + void setDefaultMaxSessionIdleTimeout(long timeout); + + /** + * Get the default maximum buffer size for binary messages. + * @return The current default maximum buffer size in bytes + */ + int getDefaultMaxBinaryMessageBufferSize(); + + /** + * Set the default maximum buffer size for binary messages. + * @param max The new default maximum buffer size in bytes + */ + void setDefaultMaxBinaryMessageBufferSize(int max); + + /** + * Get the default maximum buffer size for text messages. + * @return The current default maximum buffer size in characters + */ + int getDefaultMaxTextMessageBufferSize(); + + /** + * Set the default maximum buffer size for text messages. + * @param max The new default maximum buffer size in characters + */ + void setDefaultMaxTextMessageBufferSize(int max); + + /** + * Get the installed extensions. + * @return The set of extensions that are supported by this WebSocket + * implementation. + */ + Set<Extension> getInstalledExtensions(); +} diff --git a/src/java/javax/websocket/server/DefaultServerEndpointConfig.java b/src/java/javax/websocket/server/DefaultServerEndpointConfig.java new file mode 100644 index 00000000..7c3b8d7d --- /dev/null +++ b/src/java/javax/websocket/server/DefaultServerEndpointConfig.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Extension; + +/** + * Provides the default configuration for WebSocket server endpoints. + */ +final class DefaultServerEndpointConfig implements ServerEndpointConfig { + + private final Class<?> endpointClass; + private final String path; + private final List<String> subprotocols; + private final List<Extension> extensions; + private final List<Class<? extends Encoder>> encoders; + private final List<Class<? extends Decoder>> decoders; + private final Configurator serverEndpointConfigurator; + private final Map<String,Object> userProperties = new ConcurrentHashMap<>(); + + DefaultServerEndpointConfig( + Class<?> endpointClass, String path, + List<String> subprotocols, List<Extension> extensions, + List<Class<? extends Encoder>> encoders, + List<Class<? extends Decoder>> decoders, + Configurator serverEndpointConfigurator) { + this.endpointClass = endpointClass; + this.path = path; + this.subprotocols = subprotocols; + this.extensions = extensions; + this.encoders = encoders; + this.decoders = decoders; + this.serverEndpointConfigurator = serverEndpointConfigurator; + } + + @Override + public Class<?> getEndpointClass() { + return endpointClass; + } + + @Override + public List<Class<? extends Encoder>> getEncoders() { + return this.encoders; + } + + @Override + public List<Class<? extends Decoder>> getDecoders() { + return this.decoders; + } + + @Override + public String getPath() { + return path; + } + + @Override + public Configurator getConfigurator() { + return serverEndpointConfigurator; + } + + @Override + public final Map<String, Object> getUserProperties() { + return userProperties; + } + + @Override + public final List<String> getSubprotocols() { + return subprotocols; + } + + @Override + public final List<Extension> getExtensions() { + return extensions; + } +} diff --git a/src/java/javax/websocket/server/HandshakeRequest.java b/src/java/javax/websocket/server/HandshakeRequest.java new file mode 100644 index 00000000..f2e33273 --- /dev/null +++ b/src/java/javax/websocket/server/HandshakeRequest.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.net.URI; +import java.security.Principal; +import java.util.List; +import java.util.Map; + +/** + * Represents the HTTP request that asked to be upgraded to WebSocket. + */ +public interface HandshakeRequest { + + static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + static final String SEC_WEBSOCKET_EXTENSIONS= "Sec-WebSocket-Extensions"; + + Map<String,List<String>> getHeaders(); + + Principal getUserPrincipal(); + + URI getRequestURI(); + + boolean isUserInRole(String role); + + /** + * Get the HTTP Session object associated with this request. Object is used + * to avoid a direct dependency on the Servlet API. + * @return The javax.servlet.http.HttpSession object associated with this + * request, if any. + */ + Object getHttpSession(); + + Map<String, List<String>> getParameterMap(); + + String getQueryString(); +} diff --git a/src/java/javax/websocket/server/PathParam.java b/src/java/javax/websocket/server/PathParam.java new file mode 100644 index 00000000..ff1d085e --- /dev/null +++ b/src/java/javax/websocket/server/PathParam.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Used to annotate method parameters on POJO endpoints the the {@link + * ServerEndpoint} has been defined with a {@link ServerEndpoint#value()} that + * uses a URI template. + */ +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.PARAMETER) +public @interface PathParam { + String value(); +} diff --git a/src/java/javax/websocket/server/ServerApplicationConfig.java b/src/java/javax/websocket/server/ServerApplicationConfig.java new file mode 100644 index 00000000..b91f1c43 --- /dev/null +++ b/src/java/javax/websocket/server/ServerApplicationConfig.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.Set; + +import javax.websocket.Endpoint; + +/** + * Applications may provide an implementation of this interface to filter the + * discovered WebSocket endpoints that are deployed. Implementations of this + * class will be discovered via an ServletContainerInitializer scan. + */ +public interface ServerApplicationConfig { + + /** + * Enables applications to filter the discovered implementations of + * {@link ServerEndpointConfig}. + * + * @param scanned The {@link Endpoint} implementations found in the + * application + * @return The set of configurations for the endpoint the application + * wishes to deploy + */ + Set<ServerEndpointConfig> getEndpointConfigs( + Set<Class<? extends Endpoint>> scanned); + + /** + * Enables applications to filter the discovered classes annotated with + * {@link ServerEndpoint}. + * + * @param scanned The POJOs annotated with {@link ServerEndpoint} found in + * the application + * @return The set of POJOs the application wishes to deploy + */ + Set<Class<?>> getAnnotatedEndpointClasses(Set<Class<?>> scanned); +} diff --git a/src/java/javax/websocket/server/ServerContainer.java b/src/java/javax/websocket/server/ServerContainer.java new file mode 100644 index 00000000..3243a07c --- /dev/null +++ b/src/java/javax/websocket/server/ServerContainer.java @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import javax.websocket.DeploymentException; +import javax.websocket.WebSocketContainer; + +/** + * Provides the ability to deploy endpoints programmatically. + */ +public interface ServerContainer extends WebSocketContainer { + public abstract void addEndpoint(Class<?> clazz) throws DeploymentException; + + public abstract void addEndpoint(ServerEndpointConfig sec) + throws DeploymentException; +} diff --git a/src/java/javax/websocket/server/ServerEndpoint.java b/src/java/javax/websocket/server/ServerEndpoint.java new file mode 100644 index 00000000..43b7dfa2 --- /dev/null +++ b/src/java/javax/websocket/server/ServerEndpoint.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; + +@Retention(RetentionPolicy.RUNTIME) +@Target(ElementType.TYPE) +public @interface ServerEndpoint { + + /** + * URI or URI-template that the annotated class should be mapped to. + * @return The URI or URI-template that the annotated class should be mapped + * to. + */ + String value(); + + String[] subprotocols() default {}; + + Class<? extends Decoder>[] decoders() default {}; + + Class<? extends Encoder>[] encoders() default {}; + + public Class<? extends ServerEndpointConfig.Configurator> configurator() + default ServerEndpointConfig.Configurator.class; +} diff --git a/src/java/javax/websocket/server/ServerEndpointConfig.java b/src/java/javax/websocket/server/ServerEndpointConfig.java new file mode 100644 index 00000000..5afdf79c --- /dev/null +++ b/src/java/javax/websocket/server/ServerEndpointConfig.java @@ -0,0 +1,218 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package javax.websocket.server; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.ServiceLoader; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; + +/** + * Provides configuration information for WebSocket endpoints published to a + * server. Applications may provide their own implementation or use + * {@link Builder}. + */ +public interface ServerEndpointConfig extends EndpointConfig { + + Class<?> getEndpointClass(); + + /** + * Returns the path at which this WebSocket server endpoint has been + * registered. It may be a path or a level 0 URI template. + * @return The registered path + */ + String getPath(); + + List<String> getSubprotocols(); + + List<Extension> getExtensions(); + + Configurator getConfigurator(); + + + public final class Builder { + + public static Builder create( + Class<?> endpointClass, String path) { + return new Builder(endpointClass, path); + } + + + private final Class<?> endpointClass; + private final String path; + private List<Class<? extends Encoder>> encoders = + Collections.emptyList(); + private List<Class<? extends Decoder>> decoders = + Collections.emptyList(); + private List<String> subprotocols = Collections.emptyList(); + private List<Extension> extensions = Collections.emptyList(); + private Configurator configurator = + Configurator.fetchContainerDefaultConfigurator(); + + + private Builder(Class<?> endpointClass, + String path) { + this.endpointClass = endpointClass; + this.path = path; + } + + public ServerEndpointConfig build() { + return new DefaultServerEndpointConfig(endpointClass, path, + subprotocols, extensions, encoders, decoders, configurator); + } + + + public Builder encoders( + List<Class<? extends Encoder>> encoders) { + if (encoders == null || encoders.size() == 0) { + this.encoders = Collections.emptyList(); + } else { + this.encoders = Collections.unmodifiableList(encoders); + } + return this; + } + + + public Builder decoders( + List<Class<? extends Decoder>> decoders) { + if (decoders == null || decoders.size() == 0) { + this.decoders = Collections.emptyList(); + } else { + this.decoders = Collections.unmodifiableList(decoders); + } + return this; + } + + + public Builder subprotocols( + List<String> subprotocols) { + if (subprotocols == null || subprotocols.size() == 0) { + this.subprotocols = Collections.emptyList(); + } else { + this.subprotocols = Collections.unmodifiableList(subprotocols); + } + return this; + } + + + public Builder extensions( + List<Extension> extensions) { + if (extensions == null || extensions.size() == 0) { + this.extensions = Collections.emptyList(); + } else { + this.extensions = Collections.unmodifiableList(extensions); + } + return this; + } + + + public Builder configurator(Configurator serverEndpointConfigurator) { + if (serverEndpointConfigurator == null) { + this.configurator = Configurator.fetchContainerDefaultConfigurator(); + } else { + this.configurator = serverEndpointConfigurator; + } + return this; + } + } + + + public class Configurator { + + private static volatile Configurator defaultImpl = null; + private static final Object defaultImplLock = new Object(); + + private static final String DEFAULT_IMPL_CLASSNAME = + "nginx.unit.websocket.server.DefaultServerEndpointConfigurator"; + + public static void setDefault(Configurator def) { + synchronized (defaultImplLock) { + defaultImpl = def; + } + } + + static Configurator fetchContainerDefaultConfigurator() { + if (defaultImpl == null) { + synchronized (defaultImplLock) { + if (defaultImpl == null) { + defaultImpl = loadDefault(); + } + } + } + return defaultImpl; + } + + + private static Configurator loadDefault() { + Configurator result = null; + + ServiceLoader<Configurator> serviceLoader = + ServiceLoader.load(Configurator.class); + + Iterator<Configurator> iter = serviceLoader.iterator(); + while (result == null && iter.hasNext()) { + result = iter.next(); + } + + // Fall-back. Also used by unit tests + if (result == null) { + try { + @SuppressWarnings("unchecked") + Class<Configurator> clazz = + (Class<Configurator>) Class.forName( + DEFAULT_IMPL_CLASSNAME); + result = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException | IllegalArgumentException | + SecurityException e) { + // No options left. Just return null. + } + } + return result; + } + + public String getNegotiatedSubprotocol(List<String> supported, + List<String> requested) { + return fetchContainerDefaultConfigurator().getNegotiatedSubprotocol(supported, requested); + } + + public List<Extension> getNegotiatedExtensions(List<Extension> installed, + List<Extension> requested) { + return fetchContainerDefaultConfigurator().getNegotiatedExtensions(installed, requested); + } + + public boolean checkOrigin(String originHeaderValue) { + return fetchContainerDefaultConfigurator().checkOrigin(originHeaderValue); + } + + public void modifyHandshake(ServerEndpointConfig sec, + HandshakeRequest request, HandshakeResponse response) { + fetchContainerDefaultConfigurator().modifyHandshake(sec, request, response); + } + + public <T extends Object> T getEndpointInstance(Class<T> clazz) + throws InstantiationException { + return fetchContainerDefaultConfigurator().getEndpointInstance( + clazz); + } + } +} diff --git a/src/java/nginx/unit/Context.java b/src/java/nginx/unit/Context.java index e1482903..6fcd6018 100644 --- a/src/java/nginx/unit/Context.java +++ b/src/java/nginx/unit/Context.java @@ -98,10 +98,14 @@ import javax.servlet.http.HttpSessionEvent; import javax.servlet.http.HttpSessionIdListener; import javax.servlet.http.HttpSessionListener; +import javax.websocket.server.ServerEndpoint; + import javax.xml.parsers.DocumentBuilder; import javax.xml.parsers.DocumentBuilderFactory; import javax.xml.parsers.ParserConfigurationException; +import nginx.unit.websocket.WsSession; + import org.eclipse.jetty.http.MimeTypes; import org.w3c.dom.Document; @@ -421,6 +425,9 @@ public class Context implements ServletContext, InitParams loader_ = new AppClassLoader(urls, Context.class.getClassLoader().getParent()); + Class wsSession_class = WsSession.class; + trace("wsSession.test: " + WsSession.wsSession_test()); + ClassLoader old = Thread.currentThread().getContextClassLoader(); Thread.currentThread().setContextClassLoader(loader_); @@ -429,28 +436,30 @@ public class Context implements ServletContext, InitParams addListener(listener_classname); } - ScanResult scan_res = null; + ClassGraph classgraph = new ClassGraph() + //.verbose() + .overrideClassLoaders(loader_) + .ignoreParentClassLoaders() + .enableClassInfo() + .enableAnnotationInfo() + //.enableSystemPackages() + .whitelistModules("javax.*") + //.enableAllInfo() + ; - if (!metadata_complete_) { - ClassGraph classgraph = new ClassGraph() - //.verbose() - .overrideClassLoaders(loader_) - .ignoreParentClassLoaders() - .enableClassInfo() - .enableAnnotationInfo() - //.enableSystemPackages() - .whitelistModules("javax.*") - //.enableAllInfo() - ; - - String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim(); - - if (verbose.equals("true")) { - classgraph.verbose(); - } + String verbose = System.getProperty("nginx.unit.context.classgraph.verbose", "").trim(); + + if (verbose.equals("true")) { + classgraph.verbose(); + } + + ScanResult scan_res = classgraph.scan(); + + javax.websocket.server.ServerEndpointConfig.Configurator.setDefault(new nginx.unit.websocket.server.DefaultServerEndpointConfigurator()); - scan_res = classgraph.scan(); + loadInitializer(new nginx.unit.websocket.server.WsSci(), scan_res); + if (!metadata_complete_) { loadInitializers(scan_res); } @@ -1471,54 +1480,61 @@ public class Context implements ServletContext, InitParams ServiceLoader.load(ServletContainerInitializer.class, loader_); for (ServletContainerInitializer sci : initializers) { + loadInitializer(sci, scan_res); + } + } - trace("loadInitializers: initializer: " + sci.getClass().getName()); + private void loadInitializer(ServletContainerInitializer sci, ScanResult scan_res) + { + trace("loadInitializer: initializer: " + sci.getClass().getName()); - HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class); - if (ann == null) { - trace("loadInitializers: no HandlesTypes annotation"); - continue; - } + HandlesTypes ann = sci.getClass().getAnnotation(HandlesTypes.class); + if (ann == null) { + trace("loadInitializer: no HandlesTypes annotation"); + return; + } - Class<?>[] classes = ann.value(); - if (classes == null) { - trace("loadInitializers: no handles classes"); - continue; - } + Class<?>[] classes = ann.value(); + if (classes == null) { + trace("loadInitializer: no handles classes"); + return; + } - Set<Class<?>> handles_classes = new HashSet<>(); + Set<Class<?>> handles_classes = new HashSet<>(); - for (Class<?> c : classes) { - trace("loadInitializers: find handles: " + c.getName()); + for (Class<?> c : classes) { + trace("loadInitializer: find handles: " + c.getName()); - ClassInfoList handles = c.isInterface() + ClassInfoList handles = + c.isAnnotation() + ? scan_res.getClassesWithAnnotation(c.getName()) + : c.isInterface() ? scan_res.getClassesImplementing(c.getName()) : scan_res.getSubclasses(c.getName()); - for (ClassInfo ci : handles) { - if (ci.isInterface() - || ci.isAnnotation() - || ci.isAbstract()) - { - continue; - } - - trace("loadInitializers: handles class: " + ci.getName()); - handles_classes.add(ci.loadClass()); + for (ClassInfo ci : handles) { + if (ci.isInterface() + || ci.isAnnotation() + || ci.isAbstract()) + { + return; } - } - if (handles_classes.isEmpty()) { - trace("loadInitializers: no handles implementations"); - continue; + trace("loadInitializer: handles class: " + ci.getName()); + handles_classes.add(ci.loadClass()); } + } - try { - sci.onStartup(handles_classes, this); - metadata_complete_ = true; - } catch(Exception e) { - System.err.println("loadInitializers: exception caught: " + e.toString()); - } + if (handles_classes.isEmpty()) { + trace("loadInitializer: no handles implementations"); + return; + } + + try { + sci.onStartup(handles_classes, this); + metadata_complete_ = true; + } catch(Exception e) { + System.err.println("loadInitializer: exception caught: " + e.toString()); } } @@ -1691,6 +1707,21 @@ public class Context implements ServletContext, InitParams listener_classnames_.add(ci.getName()); } + + + ClassInfoList endpoints = scan_res.getClassesWithAnnotation(ServerEndpoint.class.getName()); + + for (ClassInfo ci : endpoints) { + if (ci.isInterface() + || ci.isAnnotation() + || ci.isAbstract()) + { + trace("scanClasses: skip server end point: " + ci.getName()); + continue; + } + + trace("scanClasses: server end point: " + ci.getName()); + } } public void stop() throws IOException diff --git a/src/java/nginx/unit/Request.java b/src/java/nginx/unit/Request.java index 98584efe..335d7980 100644 --- a/src/java/nginx/unit/Request.java +++ b/src/java/nginx/unit/Request.java @@ -16,6 +16,7 @@ import java.lang.StringBuffer; import java.net.URI; import java.net.URISyntaxException; +import java.nio.ByteBuffer; import java.nio.charset.Charset; import java.nio.charset.StandardCharsets; @@ -65,6 +66,9 @@ import org.eclipse.jetty.http.MultiPartFormInputStream; import org.eclipse.jetty.http.HttpFields; import org.eclipse.jetty.http.MimeTypes; +import nginx.unit.websocket.WsSession; +import nginx.unit.websocket.WsIOException; + public class Request implements HttpServletRequest, DynamicPathRequest { private final Context context; @@ -114,6 +118,9 @@ public class Request implements HttpServletRequest, DynamicPathRequest private boolean request_session_id_from_url = false; private Session session = null; + private WsSession wsSession = null; + private boolean skip_close_ws = false; + private final ServletRequestAttributeListener attr_listener; public static final String BARE = "nginx.unit.request.bare"; @@ -1203,11 +1210,30 @@ public class Request implements HttpServletRequest, DynamicPathRequest public <T extends HttpUpgradeHandler> T upgrade( Class<T> httpUpgradeHandlerClass) throws java.io.IOException, ServletException { - log("upgrade: " + httpUpgradeHandlerClass.getName()); + trace("upgrade: " + httpUpgradeHandlerClass.getName()); - return null; + T handler; + + try { + handler = httpUpgradeHandlerClass.getConstructor().newInstance(); + } catch (Exception e) { + throw new ServletException(e); + } + + upgrade(req_info_ptr); + + return handler; + } + + private static native void upgrade(long req_info_ptr); + + public boolean isUpgrade() + { + return isUpgrade(req_info_ptr); } + private static native boolean isUpgrade(long req_info_ptr); + @Override public String changeSessionId() { @@ -1248,5 +1274,65 @@ public class Request implements HttpServletRequest, DynamicPathRequest public static native void trace(long req_info_ptr, String msg, int msg_len); private static native Response getResponse(long req_info_ptr); + + + public void setWsSession(WsSession s) + { + wsSession = s; + } + + private void processWsFrame(ByteBuffer buf, byte opCode, boolean last) + throws IOException + { + trace("processWsFrame: " + opCode + ", [" + buf.position() + ", " + buf.limit() + "]"); + try { + wsSession.processFrame(buf, opCode, last); + } catch (WsIOException e) { + wsSession.onClose(e.getCloseReason()); + } + } + + private void closeWsSession() + { + trace("closeWsSession"); + skip_close_ws = true; + + wsSession.onClose(); + } + + public void sendWsFrame(ByteBuffer payload, byte opCode, boolean last, + long timeoutExpiry) throws IOException + { + trace("sendWsFrame: " + opCode + ", [" + payload.position() + + ", " + payload.limit() + "]"); + + if (payload.isDirect()) { + sendWsFrame(req_info_ptr, payload, payload.position(), + payload.limit() - payload.position(), opCode, last); + } else { + sendWsFrame(req_info_ptr, payload.array(), payload.position(), + payload.limit() - payload.position(), opCode, last); + } + } + + private static native void sendWsFrame(long req_info_ptr, + ByteBuffer buf, int pos, int len, byte opCode, boolean last); + + private static native void sendWsFrame(long req_info_ptr, + byte[] arr, int pos, int len, byte opCode, boolean last); + + + public void closeWs() + { + if (skip_close_ws) { + return; + } + + trace("closeWs"); + + closeWs(req_info_ptr); + } + + private static native void closeWs(long req_info_ptr); } diff --git a/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java new file mode 100644 index 00000000..147112c1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelGroupUtil.java @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.channels.AsynchronousChannelGroup; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.threads.ThreadPoolExecutor; + +/** + * This is a utility class that enables multiple {@link WsWebSocketContainer} + * instances to share a single {@link AsynchronousChannelGroup} while ensuring + * that the group is destroyed when no longer required. + */ +public class AsyncChannelGroupUtil { + + private static final StringManager sm = + StringManager.getManager(AsyncChannelGroupUtil.class); + + private static AsynchronousChannelGroup group = null; + private static int usageCount = 0; + private static final Object lock = new Object(); + + + private AsyncChannelGroupUtil() { + // Hide the default constructor + } + + + public static AsynchronousChannelGroup register() { + synchronized (lock) { + if (usageCount == 0) { + group = createAsynchronousChannelGroup(); + } + usageCount++; + return group; + } + } + + + public static void unregister() { + synchronized (lock) { + usageCount--; + if (usageCount == 0) { + group.shutdown(); + group = null; + } + } + } + + + private static AsynchronousChannelGroup createAsynchronousChannelGroup() { + // Need to do this with the right thread context class loader else the + // first web app to call this will trigger a leak + ClassLoader original = Thread.currentThread().getContextClassLoader(); + + try { + Thread.currentThread().setContextClassLoader( + AsyncIOThreadFactory.class.getClassLoader()); + + // These are the same settings as the default + // AsynchronousChannelGroup + int initialSize = Runtime.getRuntime().availableProcessors(); + ExecutorService executorService = new ThreadPoolExecutor( + 0, + Integer.MAX_VALUE, + Long.MAX_VALUE, TimeUnit.MILLISECONDS, + new SynchronousQueue<Runnable>(), + new AsyncIOThreadFactory()); + + try { + return AsynchronousChannelGroup.withCachedThreadPool( + executorService, initialSize); + } catch (IOException e) { + // No good reason for this to happen. + throw new IllegalStateException(sm.getString("asyncChannelGroup.createFail")); + } + } finally { + Thread.currentThread().setContextClassLoader(original); + } + } + + + private static class AsyncIOThreadFactory implements ThreadFactory { + + static { + // Load NewThreadPrivilegedAction since newThread() will not be able + // to if called from an InnocuousThread. + // See https://bz.apache.org/bugzilla/show_bug.cgi?id=57490 + NewThreadPrivilegedAction.load(); + } + + + @Override + public Thread newThread(final Runnable r) { + // Create the new Thread within a doPrivileged block to ensure that + // the thread inherits the current ProtectionDomain which is + // essential to be able to use this with a Java Applet. See + // https://bz.apache.org/bugzilla/show_bug.cgi?id=57091 + return AccessController.doPrivileged(new NewThreadPrivilegedAction(r)); + } + + // Non-anonymous class so that AsyncIOThreadFactory can load it + // explicitly + private static class NewThreadPrivilegedAction implements PrivilegedAction<Thread> { + + private static AtomicInteger count = new AtomicInteger(0); + + private final Runnable r; + + public NewThreadPrivilegedAction(Runnable r) { + this.r = r; + } + + @Override + public Thread run() { + Thread t = new Thread(r); + t.setName("WebSocketClient-AsyncIO-" + count.incrementAndGet()); + t.setContextClassLoader(this.getClass().getClassLoader()); + t.setDaemon(true); + return t; + } + + private static void load() { + // NO-OP. Just provides a hook to enable the class to be loaded + } + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapper.java b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java new file mode 100644 index 00000000..060ae9cb --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapper.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import javax.net.ssl.SSLException; + +/** + * This is a wrapper for a {@link java.nio.channels.AsynchronousSocketChannel} + * that limits the methods available thereby simplifying the process of + * implementing SSL/TLS support since there are fewer methods to intercept. + */ +public interface AsyncChannelWrapper { + + Future<Integer> read(ByteBuffer dst); + + <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler); + + Future<Integer> write(ByteBuffer src); + + <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler); + + void close(); + + Future<Void> handshake() throws SSLException; +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java new file mode 100644 index 00000000..5b88bfe1 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperNonSecure.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * Generally, just passes calls straight to the wrapped + * {@link AsynchronousSocketChannel}. In some cases exceptions may be swallowed + * to save them being swallowed by the calling code. + */ +public class AsyncChannelWrapperNonSecure implements AsyncChannelWrapper { + + private static final Future<Void> NOOP_FUTURE = new NoOpFuture(); + + private final AsynchronousSocketChannel socketChannel; + + public AsyncChannelWrapperNonSecure( + AsynchronousSocketChannel socketChannel) { + this.socketChannel = socketChannel; + } + + @Override + public Future<Integer> read(ByteBuffer dst) { + return socketChannel.read(dst); + } + + @Override + public <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler) { + socketChannel.read(dst, attachment, handler); + } + + @Override + public Future<Integer> write(ByteBuffer src) { + return socketChannel.write(src); + } + + @Override + public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler) { + socketChannel.write( + srcs, offset, length, timeout, unit, attachment, handler); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + // Ignore + } + } + + @Override + public Future<Void> handshake() { + return NOOP_FUTURE; + } + + + private static final class NoOpFuture implements Future<Void> { + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + return null; + } + } +} diff --git a/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java new file mode 100644 index 00000000..21654487 --- /dev/null +++ b/src/java/nginx/unit/websocket/AsyncChannelWrapperSecure.java @@ -0,0 +1,578 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.channels.CompletionHandler; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +/** + * Wraps the {@link AsynchronousSocketChannel} with SSL/TLS. This needs a lot + * more testing before it can be considered robust. + */ +public class AsyncChannelWrapperSecure implements AsyncChannelWrapper { + + private final Log log = + LogFactory.getLog(AsyncChannelWrapperSecure.class); + private static final StringManager sm = + StringManager.getManager(AsyncChannelWrapperSecure.class); + + private static final ByteBuffer DUMMY = ByteBuffer.allocate(16921); + private final AsynchronousSocketChannel socketChannel; + private final SSLEngine sslEngine; + private final ByteBuffer socketReadBuffer; + private final ByteBuffer socketWriteBuffer; + // One thread for read, one for write + private final ExecutorService executor = + Executors.newFixedThreadPool(2, new SecureIOThreadFactory()); + private AtomicBoolean writing = new AtomicBoolean(false); + private AtomicBoolean reading = new AtomicBoolean(false); + + public AsyncChannelWrapperSecure(AsynchronousSocketChannel socketChannel, + SSLEngine sslEngine) { + this.socketChannel = socketChannel; + this.sslEngine = sslEngine; + + int socketBufferSize = sslEngine.getSession().getPacketBufferSize(); + socketReadBuffer = ByteBuffer.allocateDirect(socketBufferSize); + socketWriteBuffer = ByteBuffer.allocateDirect(socketBufferSize); + } + + @Override + public Future<Integer> read(ByteBuffer dst) { + WrapperFuture<Integer,Void> future = new WrapperFuture<>(); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + + return future; + } + + @Override + public <B,A extends B> void read(ByteBuffer dst, A attachment, + CompletionHandler<Integer,B> handler) { + + WrapperFuture<Integer,B> future = + new WrapperFuture<>(handler, attachment); + + if (!reading.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentRead")); + } + + ReadTask readTask = new ReadTask(dst, future); + + executor.execute(readTask); + } + + @Override + public Future<Integer> write(ByteBuffer src) { + + WrapperFuture<Long,Void> inner = new WrapperFuture<>(); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = + new WriteTask(new ByteBuffer[] {src}, 0, 1, inner); + + executor.execute(writeTask); + + Future<Integer> future = new LongToIntegerFuture(inner); + return future; + } + + @Override + public <B,A extends B> void write(ByteBuffer[] srcs, int offset, int length, + long timeout, TimeUnit unit, A attachment, + CompletionHandler<Long,B> handler) { + + WrapperFuture<Long,B> future = + new WrapperFuture<>(handler, attachment); + + if (!writing.compareAndSet(false, true)) { + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.concurrentWrite")); + } + + WriteTask writeTask = new WriteTask(srcs, offset, length, future); + + executor.execute(writeTask); + } + + @Override + public void close() { + try { + socketChannel.close(); + } catch (IOException e) { + log.info(sm.getString("asyncChannelWrapperSecure.closeFail")); + } + executor.shutdownNow(); + } + + @Override + public Future<Void> handshake() throws SSLException { + + WrapperFuture<Void,Void> wFuture = new WrapperFuture<>(); + + Thread t = new WebSocketSslHandshakeThread(wFuture); + t.start(); + + return wFuture; + } + + + private class WriteTask implements Runnable { + + private final ByteBuffer[] srcs; + private final int offset; + private final int length; + private final WrapperFuture<Long,?> future; + + public WriteTask(ByteBuffer[] srcs, int offset, int length, + WrapperFuture<Long,?> future) { + this.srcs = srcs; + this.future = future; + this.offset = offset; + this.length = length; + } + + @Override + public void run() { + long written = 0; + + try { + for (int i = offset; i < offset + length; i++) { + ByteBuffer src = srcs[i]; + while (src.hasRemaining()) { + socketWriteBuffer.clear(); + + // Encrypt the data + SSLEngineResult r = sslEngine.wrap(src, socketWriteBuffer); + written += r.bytesConsumed(); + Status s = r.getStatus(); + + if (s == Status.OK || s == Status.BUFFER_OVERFLOW) { + // Need to write out the bytes and may need to read from + // the source again to empty it + } else { + // Status.BUFFER_UNDERFLOW - only happens on unwrap + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusWrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + + socketWriteBuffer.flip(); + + // Do the write + int toWrite = r.bytesProduced(); + while (toWrite > 0) { + Future<Integer> f = + socketChannel.write(socketWriteBuffer); + Integer socketWrite = f.get(); + toWrite -= socketWrite.intValue(); + } + } + } + + + if (writing.compareAndSet(true, false)) { + future.complete(Long.valueOf(written)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateWrite"))); + } + } catch (Exception e) { + writing.set(false); + future.fail(e); + } + } + } + + + private class ReadTask implements Runnable { + + private final ByteBuffer dest; + private final WrapperFuture<Integer,?> future; + + public ReadTask(ByteBuffer dest, WrapperFuture<Integer,?> future) { + this.dest = dest; + this.future = future; + } + + @Override + public void run() { + int read = 0; + + boolean forceRead = false; + + try { + while (read == 0) { + socketReadBuffer.compact(); + + if (forceRead) { + forceRead = false; + Future<Integer> f = socketChannel.read(socketReadBuffer); + Integer socketRead = f.get(); + if (socketRead.intValue() == -1) { + throw new EOFException(sm.getString("asyncChannelWrapperSecure.eof")); + } + } + + socketReadBuffer.flip(); + + if (socketReadBuffer.hasRemaining()) { + // Decrypt the data in the buffer + SSLEngineResult r = sslEngine.unwrap(socketReadBuffer, dest); + read += r.bytesProduced(); + Status s = r.getStatus(); + + if (s == Status.OK) { + // Bytes available for reading and there may be + // sufficient data in the socketReadBuffer to + // support further reads without reading from the + // socket + } else if (s == Status.BUFFER_UNDERFLOW) { + // There is partial data in the socketReadBuffer + if (read == 0) { + // Need more data before the partial data can be + // processed and some output generated + forceRead = true; + } + // else return the data we have and deal with the + // partial data on the next read + } else if (s == Status.BUFFER_OVERFLOW) { + // Not enough space in the destination buffer to + // store all of the data. We could use a bytes read + // value of -bufferSizeRequired to signal the new + // buffer size required but an explicit exception is + // clearer. + if (reading.compareAndSet(true, false)) { + throw new ReadBufferOverflowException(sslEngine. + getSession().getApplicationBufferSize()); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } else { + // Status.CLOSED - unexpected + throw new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.statusUnwrap")); + } + + // Check for tasks + if (r.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable runnable = sslEngine.getDelegatedTask(); + while (runnable != null) { + runnable.run(); + runnable = sslEngine.getDelegatedTask(); + } + } + } else { + forceRead = true; + } + } + + + if (reading.compareAndSet(true, false)) { + future.complete(Integer.valueOf(read)); + } else { + future.fail(new IllegalStateException(sm.getString( + "asyncChannelWrapperSecure.wrongStateRead"))); + } + } catch (RuntimeException | ReadBufferOverflowException | SSLException | EOFException | + ExecutionException | InterruptedException e) { + reading.set(false); + future.fail(e); + } + } + } + + + private class WebSocketSslHandshakeThread extends Thread { + + private final WrapperFuture<Void,Void> hFuture; + + private HandshakeStatus handshakeStatus; + private Status resultStatus; + + public WebSocketSslHandshakeThread(WrapperFuture<Void,Void> hFuture) { + this.hFuture = hFuture; + } + + @Override + public void run() { + try { + sslEngine.beginHandshake(); + // So the first compact does the right thing + socketReadBuffer.position(socketReadBuffer.limit()); + + handshakeStatus = sslEngine.getHandshakeStatus(); + resultStatus = Status.OK; + + boolean handshaking = true; + + while(handshaking) { + switch (handshakeStatus) { + case NEED_WRAP: { + socketWriteBuffer.clear(); + SSLEngineResult r = + sslEngine.wrap(DUMMY, socketWriteBuffer); + checkResult(r, true); + socketWriteBuffer.flip(); + Future<Integer> fWrite = + socketChannel.write(socketWriteBuffer); + fWrite.get(); + break; + } + case NEED_UNWRAP: { + socketReadBuffer.compact(); + if (socketReadBuffer.position() == 0 || + resultStatus == Status.BUFFER_UNDERFLOW) { + Future<Integer> fRead = + socketChannel.read(socketReadBuffer); + fRead.get(); + } + socketReadBuffer.flip(); + SSLEngineResult r = + sslEngine.unwrap(socketReadBuffer, DUMMY); + checkResult(r, false); + break; + } + case NEED_TASK: { + Runnable r = null; + while ((r = sslEngine.getDelegatedTask()) != null) { + r.run(); + } + handshakeStatus = sslEngine.getHandshakeStatus(); + break; + } + case FINISHED: { + handshaking = false; + break; + } + case NOT_HANDSHAKING: { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.notHandshaking")); + } + } + } + } catch (Exception e) { + hFuture.fail(e); + return; + } + + hFuture.complete(null); + } + + private void checkResult(SSLEngineResult result, boolean wrap) + throws SSLException { + + handshakeStatus = result.getHandshakeStatus(); + resultStatus = result.getStatus(); + + if (resultStatus != Status.OK && + (wrap || resultStatus != Status.BUFFER_UNDERFLOW)) { + throw new SSLException( + sm.getString("asyncChannelWrapperSecure.check.notOk", resultStatus)); + } + if (wrap && result.bytesConsumed() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.wrap")); + } + if (!wrap && result.bytesProduced() != 0) { + throw new SSLException(sm.getString("asyncChannelWrapperSecure.check.unwrap")); + } + } + } + + + private static class WrapperFuture<T,A> implements Future<T> { + + private final CompletionHandler<T,A> handler; + private final A attachment; + + private volatile T result = null; + private volatile Throwable throwable = null; + private CountDownLatch completionLatch = new CountDownLatch(1); + + public WrapperFuture() { + this(null, null); + } + + public WrapperFuture(CompletionHandler<T,A> handler, A attachment) { + this.handler = handler; + this.attachment = attachment; + } + + public void complete(T result) { + this.result = result; + completionLatch.countDown(); + if (handler != null) { + handler.completed(result, attachment); + } + } + + public void fail(Throwable t) { + throwable = t; + completionLatch.countDown(); + if (handler != null) { + handler.failed(throwable, attachment); + } + } + + @Override + public final boolean cancel(boolean mayInterruptIfRunning) { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isCancelled() { + // Could support cancellation by closing the connection + return false; + } + + @Override + public final boolean isDone() { + return completionLatch.getCount() > 0; + } + + @Override + public T get() throws InterruptedException, ExecutionException { + completionLatch.await(); + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + + @Override + public T get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean latchResult = completionLatch.await(timeout, unit); + if (latchResult == false) { + throw new TimeoutException(); + } + if (throwable != null) { + throw new ExecutionException(throwable); + } + return result; + } + } + + private static final class LongToIntegerFuture implements Future<Integer> { + + private final Future<Long> wrapped; + + public LongToIntegerFuture(Future<Long> wrapped) { + this.wrapped = wrapped; + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return wrapped.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return wrapped.isCancelled(); + } + + @Override + public boolean isDone() { + return wrapped.isDone(); + } + + @Override + public Integer get() throws InterruptedException, ExecutionException { + Long result = wrapped.get(); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + + @Override + public Integer get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + Long result = wrapped.get(timeout, unit); + if (result.longValue() > Integer.MAX_VALUE) { + throw new ExecutionException(sm.getString( + "asyncChannelWrapperSecure.tooBig", result), null); + } + return Integer.valueOf(result.intValue()); + } + } + + + private static class SecureIOThreadFactory implements ThreadFactory { + + private AtomicInteger count = new AtomicInteger(0); + + @Override + public Thread newThread(Runnable r) { + Thread t = new Thread(r); + t.setName("WebSocketClient-SecureIO-" + count.incrementAndGet()); + // No need to set the context class loader. The threads will be + // cleaned up when the connection is closed. + t.setDaemon(true); + return t; + } + } +} diff --git a/src/java/nginx/unit/websocket/AuthenticationException.java b/src/java/nginx/unit/websocket/AuthenticationException.java new file mode 100644 index 00000000..001f1829 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticationException.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +/** + * Exception thrown on authentication error connecting to a remote + * websocket endpoint. + */ +public class AuthenticationException extends Exception { + + private static final long serialVersionUID = 5709887412240096441L; + + /** + * Create authentication exception. + * @param message the error message + */ + public AuthenticationException(String message) { + super(message); + } + +} diff --git a/src/java/nginx/unit/websocket/Authenticator.java b/src/java/nginx/unit/websocket/Authenticator.java new file mode 100644 index 00000000..87b3ce6d --- /dev/null +++ b/src/java/nginx/unit/websocket/Authenticator.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashMap; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Base class for the authentication methods used by the websocket client. + */ +public abstract class Authenticator { + private static final Pattern pattern = Pattern + .compile("(\\w+)\\s*=\\s*(\"([^\"]+)\"|([^,=\"]+))\\s*,?"); + + /** + * Generate the authentication header that will be sent to the server. + * @param requestUri The request URI + * @param WWWAuthenticate The server auth challenge + * @param UserProperties The user information + * @return The auth header + * @throws AuthenticationException When an error occurs + */ + public abstract String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> UserProperties) throws AuthenticationException; + + /** + * Get the authentication method. + * @return the auth scheme + */ + public abstract String getSchemeName(); + + /** + * Utility method to parse the authentication header. + * @param WWWAuthenticate The server auth challenge + * @return the parsed header + */ + public Map<String, String> parseWWWAuthenticateHeader(String WWWAuthenticate) { + + Matcher m = pattern.matcher(WWWAuthenticate); + Map<String, String> challenge = new HashMap<>(); + + while (m.find()) { + String key = m.group(1); + String qtedValue = m.group(3); + String value = m.group(4); + + challenge.put(key, qtedValue != null ? qtedValue : value); + + } + + return challenge; + + } + +} diff --git a/src/java/nginx/unit/websocket/AuthenticatorFactory.java b/src/java/nginx/unit/websocket/AuthenticatorFactory.java new file mode 100644 index 00000000..7d46d7f9 --- /dev/null +++ b/src/java/nginx/unit/websocket/AuthenticatorFactory.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.Iterator; +import java.util.ServiceLoader; + +/** + * Utility method to return the appropriate authenticator according to + * the scheme that the server uses. + */ +public class AuthenticatorFactory { + + /** + * Return a new authenticator instance. + * @param authScheme The scheme used + * @return the authenticator + */ + public static Authenticator getAuthenticator(String authScheme) { + + Authenticator auth = null; + switch (authScheme.toLowerCase()) { + + case BasicAuthenticator.schemeName: + auth = new BasicAuthenticator(); + break; + + case DigestAuthenticator.schemeName: + auth = new DigestAuthenticator(); + break; + + default: + auth = loadAuthenticators(authScheme); + break; + } + + return auth; + + } + + private static Authenticator loadAuthenticators(String authScheme) { + ServiceLoader<Authenticator> serviceLoader = ServiceLoader.load(Authenticator.class); + Iterator<Authenticator> auths = serviceLoader.iterator(); + + while (auths.hasNext()) { + Authenticator auth = auths.next(); + if (auth.getSchemeName().equalsIgnoreCase(authScheme)) + return auth; + } + + return null; + } + +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcess.java b/src/java/nginx/unit/websocket/BackgroundProcess.java new file mode 100644 index 00000000..0d2e1288 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcess.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public interface BackgroundProcess { + + void backgroundProcess(); + + void setProcessPeriod(int period); + + int getProcessPeriod(); +} diff --git a/src/java/nginx/unit/websocket/BackgroundProcessManager.java b/src/java/nginx/unit/websocket/BackgroundProcessManager.java new file mode 100644 index 00000000..d8b1b950 --- /dev/null +++ b/src/java/nginx/unit/websocket/BackgroundProcessManager.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.HashSet; +import java.util.Set; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Provides a background processing mechanism that triggers roughly once a + * second. The class maintains a thread that only runs when there is at least + * one instance of {@link BackgroundProcess} registered. + */ +public class BackgroundProcessManager { + + private final Log log = + LogFactory.getLog(BackgroundProcessManager.class); + private static final StringManager sm = + StringManager.getManager(BackgroundProcessManager.class); + private static final BackgroundProcessManager instance; + + + static { + instance = new BackgroundProcessManager(); + } + + + public static BackgroundProcessManager getInstance() { + return instance; + } + + private final Set<BackgroundProcess> processes = new HashSet<>(); + private final Object processesLock = new Object(); + private WsBackgroundThread wsBackgroundThread = null; + + private BackgroundProcessManager() { + // Hide default constructor + } + + + public void register(BackgroundProcess process) { + synchronized (processesLock) { + if (processes.size() == 0) { + wsBackgroundThread = new WsBackgroundThread(this); + wsBackgroundThread.setContextClassLoader( + this.getClass().getClassLoader()); + wsBackgroundThread.setDaemon(true); + wsBackgroundThread.start(); + } + processes.add(process); + } + } + + + public void unregister(BackgroundProcess process) { + synchronized (processesLock) { + processes.remove(process); + if (wsBackgroundThread != null && processes.size() == 0) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private void process() { + Set<BackgroundProcess> currentProcesses = new HashSet<>(); + synchronized (processesLock) { + currentProcesses.addAll(processes); + } + for (BackgroundProcess process : currentProcesses) { + try { + process.backgroundProcess(); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString( + "backgroundProcessManager.processFailed"), t); + } + } + } + + + /* + * For unit testing. + */ + int getProcessCount() { + synchronized (processesLock) { + return processes.size(); + } + } + + + void shutdown() { + synchronized (processesLock) { + processes.clear(); + if (wsBackgroundThread != null) { + wsBackgroundThread.halt(); + wsBackgroundThread = null; + } + } + } + + + private static class WsBackgroundThread extends Thread { + + private final BackgroundProcessManager manager; + private volatile boolean running = true; + + public WsBackgroundThread(BackgroundProcessManager manager) { + setName("WebSocket background processing"); + this.manager = manager; + } + + @Override + public void run() { + while (running) { + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // Ignore + } + manager.process(); + } + } + + public void halt() { + setName("WebSocket background processing - stopping"); + running = false; + } + } +} diff --git a/src/java/nginx/unit/websocket/BasicAuthenticator.java b/src/java/nginx/unit/websocket/BasicAuthenticator.java new file mode 100644 index 00000000..1b1a6b83 --- /dev/null +++ b/src/java/nginx/unit/websocket/BasicAuthenticator.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Map; + +/** + * Authenticator supporting the BASIC auth method. + */ +public class BasicAuthenticator extends Authenticator { + + public static final String schemeName = "basic"; + public static final String charsetparam = "charset"; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Basic authentication due to missing user/password"); + } + + Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String userPass = userName + ":" + password; + Charset charset; + + if (wwwAuthenticate.get(charsetparam) != null + && wwwAuthenticate.get(charsetparam).equalsIgnoreCase("UTF-8")) { + charset = StandardCharsets.UTF_8; + } else { + charset = StandardCharsets.ISO_8859_1; + } + + String base64 = Base64.getEncoder().encodeToString(userPass.getBytes(charset)); + + return " Basic " + base64; + } + + @Override + public String getSchemeName() { + return schemeName; + } + +} diff --git a/src/java/nginx/unit/websocket/Constants.java b/src/java/nginx/unit/websocket/Constants.java new file mode 100644 index 00000000..38b22fe0 --- /dev/null +++ b/src/java/nginx/unit/websocket/Constants.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +import javax.websocket.Extension; + +/** + * Internal implementation constants. + */ +public class Constants { + + // OP Codes + public static final byte OPCODE_CONTINUATION = 0x00; + public static final byte OPCODE_TEXT = 0x01; + public static final byte OPCODE_BINARY = 0x02; + public static final byte OPCODE_CLOSE = 0x08; + public static final byte OPCODE_PING = 0x09; + public static final byte OPCODE_PONG = 0x0A; + + // Internal OP Codes + // RFC 6455 limits OP Codes to 4 bits so these should never clash + // Always set bit 4 so these will be treated as control codes + static final byte INTERNAL_OPCODE_FLUSH = 0x18; + + // Buffers + static final int DEFAULT_BUFFER_SIZE = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_BUFFER_SIZE", 8 * 1024) + .intValue(); + + // Client connection + /** + * Property name to set to configure the value that is passed to + * {@link javax.net.ssl.SSLEngine#setEnabledProtocols(String[])}. The value + * should be a comma separated string. + */ + public static final String SSL_PROTOCOLS_PROPERTY = + "nginx.unit.websocket.SSL_PROTOCOLS"; + public static final String SSL_TRUSTSTORE_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE"; + public static final String SSL_TRUSTSTORE_PWD_PROPERTY = + "nginx.unit.websocket.SSL_TRUSTSTORE_PWD"; + public static final String SSL_TRUSTSTORE_PWD_DEFAULT = "changeit"; + /** + * Property name to set to configure used SSLContext. The value should be an + * instance of SSLContext. If this property is present, the SSL_TRUSTSTORE* + * properties are ignored. + */ + public static final String SSL_CONTEXT_PROPERTY = + "nginx.unit.websocket.SSL_CONTEXT"; + /** + * Property name to set to configure the timeout (in milliseconds) when + * establishing a WebSocket connection to server. The default is + * {@link #IO_TIMEOUT_MS_DEFAULT}. + */ + public static final String IO_TIMEOUT_MS_PROPERTY = + "nginx.unit.websocket.IO_TIMEOUT_MS"; + public static final long IO_TIMEOUT_MS_DEFAULT = 5000; + + // RFC 2068 recommended a limit of 5 + // Most browsers have a default limit of 20 + public static final String MAX_REDIRECTIONS_PROPERTY = + "nginx.unit.websocket.MAX_REDIRECTIONS"; + public static final int MAX_REDIRECTIONS_DEFAULT = 20; + + // HTTP upgrade header names and values + public static final String HOST_HEADER_NAME = "Host"; + public static final String UPGRADE_HEADER_NAME = "Upgrade"; + public static final String UPGRADE_HEADER_VALUE = "websocket"; + public static final String ORIGIN_HEADER_NAME = "Origin"; + public static final String CONNECTION_HEADER_NAME = "Connection"; + public static final String CONNECTION_HEADER_VALUE = "upgrade"; + public static final String LOCATION_HEADER_NAME = "Location"; + public static final String AUTHORIZATION_HEADER_NAME = "Authorization"; + public static final String WWW_AUTHENTICATE_HEADER_NAME = "WWW-Authenticate"; + public static final String WS_VERSION_HEADER_NAME = "Sec-WebSocket-Version"; + public static final String WS_VERSION_HEADER_VALUE = "13"; + public static final String WS_KEY_HEADER_NAME = "Sec-WebSocket-Key"; + public static final String WS_PROTOCOL_HEADER_NAME = "Sec-WebSocket-Protocol"; + public static final String WS_EXTENSIONS_HEADER_NAME = "Sec-WebSocket-Extensions"; + + /// HTTP redirection status codes + public static final int MULTIPLE_CHOICES = 300; + public static final int MOVED_PERMANENTLY = 301; + public static final int FOUND = 302; + public static final int SEE_OTHER = 303; + public static final int USE_PROXY = 305; + public static final int TEMPORARY_REDIRECT = 307; + + // Configuration for Origin header in client + static final String DEFAULT_ORIGIN_HEADER_VALUE = + System.getProperty("nginx.unit.websocket.DEFAULT_ORIGIN_HEADER_VALUE"); + + // Configuration for blocking sends + public static final String BLOCKING_SEND_TIMEOUT_PROPERTY = + "nginx.unit.websocket.BLOCKING_SEND_TIMEOUT"; + // Milliseconds so this is 20 seconds + public static final long DEFAULT_BLOCKING_SEND_TIMEOUT = 20 * 1000; + + // Configuration for background processing checks intervals + static final int DEFAULT_PROCESS_PERIOD = Integer.getInteger( + "nginx.unit.websocket.DEFAULT_PROCESS_PERIOD", 10) + .intValue(); + + public static final String WS_AUTHENTICATION_USER_NAME = "nginx.unit.websocket.WS_AUTHENTICATION_USER_NAME"; + public static final String WS_AUTHENTICATION_PASSWORD = "nginx.unit.websocket.WS_AUTHENTICATION_PASSWORD"; + + /* Configuration for extensions + * Note: These options are primarily present to enable this implementation + * to pass compliance tests. They are expected to be removed once + * the WebSocket API includes a mechanism for adding custom extensions + * and disabling built-in extensions. + */ + static final boolean DISABLE_BUILTIN_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.DISABLE_BUILTIN_EXTENSIONS"); + static final boolean ALLOW_UNSUPPORTED_EXTENSIONS = + Boolean.getBoolean("nginx.unit.websocket.ALLOW_UNSUPPORTED_EXTENSIONS"); + + // Configuration for stream behavior + static final boolean STREAMS_DROP_EMPTY_MESSAGES = + Boolean.getBoolean("nginx.unit.websocket.STREAMS_DROP_EMPTY_MESSAGES"); + + public static final boolean STRICT_SPEC_COMPLIANCE = + Boolean.getBoolean("nginx.unit.websocket.STRICT_SPEC_COMPLIANCE"); + + public static final List<Extension> INSTALLED_EXTENSIONS; + + static { + if (DISABLE_BUILTIN_EXTENSIONS) { + INSTALLED_EXTENSIONS = Collections.unmodifiableList(new ArrayList<Extension>()); + } else { + List<Extension> installed = new ArrayList<>(1); + installed.add(new WsExtension("permessage-deflate")); + INSTALLED_EXTENSIONS = Collections.unmodifiableList(installed); + } + } + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/DecoderEntry.java b/src/java/nginx/unit/websocket/DecoderEntry.java new file mode 100644 index 00000000..36112ef4 --- /dev/null +++ b/src/java/nginx/unit/websocket/DecoderEntry.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Decoder; + +public class DecoderEntry { + + private final Class<?> clazz; + private final Class<? extends Decoder> decoderClazz; + + public DecoderEntry(Class<?> clazz, + Class<? extends Decoder> decoderClazz) { + this.clazz = clazz; + this.decoderClazz = decoderClazz; + } + + public Class<?> getClazz() { + return clazz; + } + + public Class<? extends Decoder> getDecoderClazz() { + return decoderClazz; + } +} diff --git a/src/java/nginx/unit/websocket/DigestAuthenticator.java b/src/java/nginx/unit/websocket/DigestAuthenticator.java new file mode 100644 index 00000000..9530c303 --- /dev/null +++ b/src/java/nginx/unit/websocket/DigestAuthenticator.java @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.Map; + +import org.apache.tomcat.util.security.MD5Encoder; + +/** + * Authenticator supporting the DIGEST auth method. + */ +public class DigestAuthenticator extends Authenticator { + + public static final String schemeName = "digest"; + private SecureRandom cnonceGenerator; + private int nonceCount = 0; + private long cNonce; + + @Override + public String getAuthorization(String requestUri, String WWWAuthenticate, + Map<String, Object> userProperties) throws AuthenticationException { + + String userName = (String) userProperties.get(Constants.WS_AUTHENTICATION_USER_NAME); + String password = (String) userProperties.get(Constants.WS_AUTHENTICATION_PASSWORD); + + if (userName == null || password == null) { + throw new AuthenticationException( + "Failed to perform Digest authentication due to missing user/password"); + } + + Map<String, String> wwwAuthenticate = parseWWWAuthenticateHeader(WWWAuthenticate); + + String realm = wwwAuthenticate.get("realm"); + String nonce = wwwAuthenticate.get("nonce"); + String messageQop = wwwAuthenticate.get("qop"); + String algorithm = wwwAuthenticate.get("algorithm") == null ? "MD5" + : wwwAuthenticate.get("algorithm"); + String opaque = wwwAuthenticate.get("opaque"); + + StringBuilder challenge = new StringBuilder(); + + if (!messageQop.isEmpty()) { + if (cnonceGenerator == null) { + cnonceGenerator = new SecureRandom(); + } + + cNonce = cnonceGenerator.nextLong(); + nonceCount++; + } + + challenge.append("Digest "); + challenge.append("username =\"" + userName + "\","); + challenge.append("realm=\"" + realm + "\","); + challenge.append("nonce=\"" + nonce + "\","); + challenge.append("uri=\"" + requestUri + "\","); + + try { + challenge.append("response=\"" + calculateRequestDigest(requestUri, userName, password, + realm, nonce, messageQop, algorithm) + "\","); + } + + catch (NoSuchAlgorithmException e) { + throw new AuthenticationException( + "Unable to generate request digest " + e.getMessage()); + } + + challenge.append("algorithm=" + algorithm + ","); + challenge.append("opaque=\"" + opaque + "\","); + + if (!messageQop.isEmpty()) { + challenge.append("qop=\"" + messageQop + "\""); + challenge.append(",cnonce=\"" + cNonce + "\","); + challenge.append("nc=" + String.format("%08X", Integer.valueOf(nonceCount))); + } + + return challenge.toString(); + + } + + private String calculateRequestDigest(String requestUri, String userName, String password, + String realm, String nonce, String qop, String algorithm) + throws NoSuchAlgorithmException { + + StringBuilder preDigest = new StringBuilder(); + String A1; + + if (algorithm.equalsIgnoreCase("MD5")) + A1 = userName + ":" + realm + ":" + password; + + else + A1 = encodeMD5(userName + ":" + realm + ":" + password) + ":" + nonce + ":" + cNonce; + + /* + * If the "qop" value is "auth-int", then A2 is: A2 = Method ":" + * digest-uri-value ":" H(entity-body) since we do not have an entity-body, A2 = + * Method ":" digest-uri-value for auth and auth_int + */ + String A2 = "GET:" + requestUri; + + preDigest.append(encodeMD5(A1)); + preDigest.append(":"); + preDigest.append(nonce); + + if (qop.toLowerCase().contains("auth")) { + preDigest.append(":"); + preDigest.append(String.format("%08X", Integer.valueOf(nonceCount))); + preDigest.append(":"); + preDigest.append(String.valueOf(cNonce)); + preDigest.append(":"); + preDigest.append(qop); + } + + preDigest.append(":"); + preDigest.append(encodeMD5(A2)); + + return encodeMD5(preDigest.toString()); + + } + + private String encodeMD5(String value) throws NoSuchAlgorithmException { + byte[] bytesOfMessage = value.getBytes(StandardCharsets.ISO_8859_1); + MessageDigest md = MessageDigest.getInstance("MD5"); + byte[] thedigest = md.digest(bytesOfMessage); + + return MD5Encoder.encode(thedigest); + } + + @Override + public String getSchemeName() { + return schemeName; + } +} diff --git a/src/java/nginx/unit/websocket/FutureToSendHandler.java b/src/java/nginx/unit/websocket/FutureToSendHandler.java new file mode 100644 index 00000000..4a0809cb --- /dev/null +++ b/src/java/nginx/unit/websocket/FutureToSendHandler.java @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.tomcat.util.res.StringManager; + + +/** + * Converts a Future to a SendHandler. + */ +class FutureToSendHandler implements Future<Void>, SendHandler { + + private static final StringManager sm = StringManager.getManager(FutureToSendHandler.class); + + private final CountDownLatch latch = new CountDownLatch(1); + private final WsSession wsSession; + private volatile AtomicReference<SendResult> result = new AtomicReference<>(null); + + public FutureToSendHandler(WsSession wsSession) { + this.wsSession = wsSession; + } + + + // --------------------------------------------------------- SendHandler + + @Override + public void onResult(SendResult result) { + this.result.compareAndSet(null, result); + latch.countDown(); + } + + + // -------------------------------------------------------------- Future + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isCancelled() { + // Cancelling the task is not supported + return false; + } + + @Override + public boolean isDone() { + return latch.getCount() == 0; + } + + @Override + public Void get() throws InterruptedException, + ExecutionException { + try { + wsSession.registerFuture(this); + latch.await(); + } finally { + wsSession.unregisterFuture(this); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } + + @Override + public Void get(long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, + TimeoutException { + boolean retval = false; + try { + wsSession.registerFuture(this); + retval = latch.await(timeout, unit); + } finally { + wsSession.unregisterFuture(this); + + } + if (retval == false) { + throw new TimeoutException(sm.getString("futureToSendHandler.timeout", + Long.valueOf(timeout), unit.toString().toLowerCase())); + } + if (result.get().getException() != null) { + throw new ExecutionException(result.get().getException()); + } + return null; + } +} diff --git a/src/java/nginx/unit/websocket/LocalStrings.properties b/src/java/nginx/unit/websocket/LocalStrings.properties new file mode 100644 index 00000000..aeafe082 --- /dev/null +++ b/src/java/nginx/unit/websocket/LocalStrings.properties @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +asyncChannelGroup.createFail=Unable to create dedicated AsynchronousChannelGroup for WebSocket clients which is required to prevent memory leaks in complex class loader environments like JavaEE containers + +asyncChannelWrapperSecure.closeFail=Failed to close channel cleanly +asyncChannelWrapperSecure.check.notOk=TLS handshake returned an unexpected status [{0}] +asyncChannelWrapperSecure.check.unwrap=Bytes were written to the output during a read +asyncChannelWrapperSecure.check.wrap=Bytes were consumed from the input during a write +asyncChannelWrapperSecure.concurrentRead=Concurrent read operations are not permitted +asyncChannelWrapperSecure.concurrentWrite=Concurrent write operations are not permitted +asyncChannelWrapperSecure.eof=Unexpected end of stream +asyncChannelWrapperSecure.notHandshaking=Unexpected state [NOT_HANDSHAKING] during TLS handshake +asyncChannelWrapperSecure.readOverflow=Buffer overflow. [{0}] bytes to write into a [{1}] byte buffer that already contained [{2}] bytes. +asyncChannelWrapperSecure.statusUnwrap=Unexpected Status of SSLEngineResult after an unwrap() operation +asyncChannelWrapperSecure.statusWrap=Unexpected Status of SSLEngineResult after a wrap() operation +asyncChannelWrapperSecure.tooBig=The result [{0}] is too big to be expressed as an Integer +asyncChannelWrapperSecure.wrongStateRead=Flag that indicates a read is in progress was found to be false (it should have been true) when trying to complete a read operation +asyncChannelWrapperSecure.wrongStateWrite=Flag that indicates a write is in progress was found to be false (it should have been true) when trying to complete a write operation + +backgroundProcessManager.processFailed=A background process failed + +caseInsensitiveKeyMap.nullKey=Null keys are not permitted + +futureToSendHandler.timeout=Operation timed out after waiting [{0}] [{1}] to complete + +perMessageDeflate.deflateFailed=Failed to decompress a compressed WebSocket frame +perMessageDeflate.duplicateParameter=Duplicate definition of the [{0}] extension parameter +perMessageDeflate.invalidWindowSize=An invalid windows of [{1}] size was specified for [{0}]. Valid values are whole numbers from 8 to 15 inclusive. +perMessageDeflate.unknownParameter=An unknown extension parameter [{0}] was defined + +transformerFactory.unsupportedExtension=The extension [{0}] is not supported + +util.notToken=An illegal extension parameter was specified with name [{0}] and value [{1}] +util.invalidMessageHandler=The message handler provided does not have an onMessage(Object) method +util.invalidType=Unable to coerce value [{0}] to type [{1}]. That type is not supported. +util.unknownDecoderType=The Decoder type [{0}] is not recognized + +# Note the wsFrame.* messages are used as close reasons in WebSocket control +# frames and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsFrame.alreadyResumed=Message receiving has already been resumed. +wsFrame.alreadySuspended=Message receiving has already been suspended. +wsFrame.bufferTooSmall=No async message support and buffer too small. Buffer size: [{0}], Message size: [{1}] +wsFrame.byteToLongFail=Too many bytes ([{0}]) were provided to be converted into a long +wsFrame.closed=New frame received after a close control frame +wsFrame.controlFragmented=A fragmented control frame was received but control frames may not be fragmented +wsFrame.controlPayloadTooBig=A control frame was sent with a payload of size [{0}] which is larger than the maximum permitted of 125 bytes +wsFrame.controlNoFin=A control frame was sent that did not have the fin bit set. Control frames are not permitted to use continuation frames. +wsFrame.illegalReadState=Unexpected read state [{0}] +wsFrame.invalidOpCode= A WebSocket frame was sent with an unrecognised opCode of [{0}] +wsFrame.invalidUtf8=A WebSocket text frame was received that could not be decoded to UTF-8 because it contained invalid byte sequences +wsFrame.invalidUtf8Close=A WebSocket close frame was received with a close reason that contained invalid UTF-8 byte sequences +wsFrame.ioeTriggeredClose=An unrecoverable IOException occurred so the connection was closed +wsFrame.messageTooBig=The message was [{0}] bytes long but the MessageHandler has a limit of [{1}] bytes +wsFrame.noContinuation=A new message was started when a continuation frame was expected +wsFrame.notMasked=The client frame was not masked but all client frames must be masked +wsFrame.oneByteCloseCode=The client sent a close frame with a single byte payload which is not valid +wsFrame.partialHeaderComplete=WebSocket frame received. fin [{0}], rsv [{1}], OpCode [{2}], payload length [{3}] +wsFrame.sessionClosed=The client data cannot be processed because the session has already been closed +wsFrame.suspendRequested=Suspend of the message receiving has already been requested. +wsFrame.textMessageTooBig=The decoded text message was too big for the output buffer and the endpoint does not support partial messages +wsFrame.wrongRsv=The client frame set the reserved bits to [{0}] for a message with opCode [{1}] which was not supported by this endpoint + +wsFrameClient.ioe=Failure while reading data sent by server + +wsHandshakeRequest.invalidUri=The string [{0}] cannot be used to construct a valid URI +wsHandshakeRequest.unknownScheme=The scheme [{0}] in the request is not recognised + +wsRemoteEndpoint.acquireTimeout=The current message was not fully sent within the specified timeout +wsRemoteEndpoint.closed=Message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedDuringMessage=The remainder of the message will not be sent because the WebSocket session has been closed +wsRemoteEndpoint.closedOutputStream=This method may not be called as the OutputStream has been closed +wsRemoteEndpoint.closedWriter=This method may not be called as the Writer has been closed +wsRemoteEndpoint.changeType=When sending a fragmented message, all fragments must be of the same type +wsRemoteEndpoint.concurrentMessageSend=Messages may not be sent concurrently even when using the asynchronous send messages. The client must wait for the previous message to complete before sending the next. +wsRemoteEndpoint.flushOnCloseFailed=Batched messages still enabled after session has been closed. Unable to flush remaining batched message. +wsRemoteEndpoint.invalidEncoder=The specified encoder of type [{0}] could not be instantiated +wsRemoteEndpoint.noEncoder=No encoder specified for object of class [{0}] +wsRemoteEndpoint.nullData=Invalid null data argument +wsRemoteEndpoint.nullHandler=Invalid null handler argument +wsRemoteEndpoint.sendInterrupt=The current thread was interrupted while waiting for a blocking send to complete +wsRemoteEndpoint.tooMuchData=Ping or pong may not send more than 125 bytes +wsRemoteEndpoint.wrongState=The remote endpoint was in state [{0}] which is an invalid state for called method + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsSession.timeout=The WebSocket session [{0}] timeout expired + +wsSession.closed=The WebSocket session [{0}] has been closed and no method (apart from close()) may be called on a closed session +wsSession.created=Created WebSocket session [{0}] +wsSession.doClose=Closing WebSocket session [{1}] +wsSession.duplicateHandlerBinary=A binary message handler has already been configured +wsSession.duplicateHandlerPong=A pong message handler has already been configured +wsSession.duplicateHandlerText=A text message handler has already been configured +wsSession.invalidHandlerTypePong=A pong message handler must implement MessageHandler.Whole +wsSession.flushFailOnClose=Failed to flush batched messages on session close +wsSession.messageFailed=Unable to write the complete message as the WebSocket connection has been closed +wsSession.sendCloseFail=Failed to send close message for session [{0}] to remote endpoint +wsSession.removeHandlerFailed=Unable to remove the handler [{0}] as it was not registered with this session +wsSession.unknownHandler=Unable to add the message handler [{0}] as it was for the unrecognised type [{1}] +wsSession.unknownHandlerType=Unable to add the message handler [{0}] as it was wrapped as the unrecognised type [{1}] +wsSession.instanceNew=Endpoint instance registration failed +wsSession.instanceDestroy=Endpoint instance unregistration failed + +# Note the following message is used as a close reason in a WebSocket control +# frame and therefore must be 123 bytes (not characters) or less in length. +# Messages are encoded using UTF-8 where a single character may be encoded in +# as many as 4 bytes. +wsWebSocketContainer.shutdown=The web application is stopping + +wsWebSocketContainer.defaultConfiguratorFail=Failed to create the default configurator +wsWebSocketContainer.endpointCreateFail=Failed to create a local endpoint of type [{0}] +wsWebSocketContainer.maxBuffer=This implementation limits the maximum size of a buffer to Integer.MAX_VALUE +wsWebSocketContainer.missingAnnotation=Cannot use POJO class [{0}] as it is not annotated with @ClientEndpoint +wsWebSocketContainer.sessionCloseFail=Session with ID [{0}] did not close cleanly + +wsWebSocketContainer.asynchronousSocketChannelFail=Unable to open a connection to the server +wsWebSocketContainer.httpRequestFailed=The HTTP request to initiate the WebSocket connection failed +wsWebSocketContainer.invalidExtensionParameters=The server responded with extension parameters the client is unable to support +wsWebSocketContainer.invalidHeader=Unable to parse HTTP header as no colon is present to delimit header name and header value in [{0}]. The header has been skipped. +wsWebSocketContainer.invalidStatus=The HTTP response from the server [{0}] did not permit the HTTP upgrade to WebSocket +wsWebSocketContainer.invalidSubProtocol=The WebSocket server returned multiple values for the Sec-WebSocket-Protocol header +wsWebSocketContainer.pathNoHost=No host was specified in URI +wsWebSocketContainer.pathWrongScheme=The scheme [{0}] is not supported. The supported schemes are ws and wss +wsWebSocketContainer.proxyConnectFail=Failed to connect to the configured Proxy [{0}]. The HTTP response code was [{1}] +wsWebSocketContainer.sslEngineFail=Unable to create SSLEngine to support SSL/TLS connections +wsWebSocketContainer.missingLocationHeader=Failed to handle HTTP response code [{0}]. Missing Location header in response +wsWebSocketContainer.redirectThreshold=Cyclic Location header [{0}] detected / reached max number of redirects [{1}] of max [{2}] +wsWebSocketContainer.unsupportedAuthScheme=Failed to handle HTTP response code [{0}]. Unsupported Authentication scheme [{1}] returned in response +wsWebSocketContainer.failedAuthentication=Failed to handle HTTP response code [{0}]. Authentication header was not accepted by server. +wsWebSocketContainer.missingWWWAuthenticateHeader=Failed to handle HTTP response code [{0}]. Missing WWW-Authenticate header in response diff --git a/src/java/nginx/unit/websocket/MessageHandlerResult.java b/src/java/nginx/unit/websocket/MessageHandlerResult.java new file mode 100644 index 00000000..8d532d1e --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResult.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public class MessageHandlerResult { + + private final MessageHandler handler; + private final MessageHandlerResultType type; + + + public MessageHandlerResult(MessageHandler handler, + MessageHandlerResultType type) { + this.handler = handler; + this.type = type; + } + + + public MessageHandler getHandler() { + return handler; + } + + + public MessageHandlerResultType getType() { + return type; + } +} diff --git a/src/java/nginx/unit/websocket/MessageHandlerResultType.java b/src/java/nginx/unit/websocket/MessageHandlerResultType.java new file mode 100644 index 00000000..1961bb4f --- /dev/null +++ b/src/java/nginx/unit/websocket/MessageHandlerResultType.java @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum MessageHandlerResultType { + BINARY, + TEXT, + PONG +} diff --git a/src/java/nginx/unit/websocket/MessagePart.java b/src/java/nginx/unit/websocket/MessagePart.java new file mode 100644 index 00000000..b52c26f1 --- /dev/null +++ b/src/java/nginx/unit/websocket/MessagePart.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.SendHandler; + +class MessagePart { + private final boolean fin; + private final int rsv; + private final byte opCode; + private final ByteBuffer payload; + private final SendHandler intermediateHandler; + private volatile SendHandler endHandler; + private final long blockingWriteTimeoutExpiry; + + public MessagePart( boolean fin, int rsv, byte opCode, ByteBuffer payload, + SendHandler intermediateHandler, SendHandler endHandler, + long blockingWriteTimeoutExpiry) { + this.fin = fin; + this.rsv = rsv; + this.opCode = opCode; + this.payload = payload; + this.intermediateHandler = intermediateHandler; + this.endHandler = endHandler; + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + } + + + public boolean isFin() { + return fin; + } + + + public int getRsv() { + return rsv; + } + + + public byte getOpCode() { + return opCode; + } + + + public ByteBuffer getPayload() { + return payload; + } + + + public SendHandler getIntermediateHandler() { + return intermediateHandler; + } + + + public SendHandler getEndHandler() { + return endHandler; + } + + public void setEndHandler(SendHandler endHandler) { + this.endHandler = endHandler; + } + + public long getBlockingWriteTimeoutExpiry() { + return blockingWriteTimeoutExpiry; + } +} + + diff --git a/src/java/nginx/unit/websocket/PerMessageDeflate.java b/src/java/nginx/unit/websocket/PerMessageDeflate.java new file mode 100644 index 00000000..88e0a0bc --- /dev/null +++ b/src/java/nginx/unit/websocket/PerMessageDeflate.java @@ -0,0 +1,476 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.zip.DataFormatException; +import java.util.zip.Deflater; +import java.util.zip.Inflater; + +import javax.websocket.Extension; +import javax.websocket.Extension.Parameter; +import javax.websocket.SendHandler; + +import org.apache.tomcat.util.res.StringManager; + +public class PerMessageDeflate implements Transformation { + + private static final StringManager sm = StringManager.getManager(PerMessageDeflate.class); + + private static final String SERVER_NO_CONTEXT_TAKEOVER = "server_no_context_takeover"; + private static final String CLIENT_NO_CONTEXT_TAKEOVER = "client_no_context_takeover"; + private static final String SERVER_MAX_WINDOW_BITS = "server_max_window_bits"; + private static final String CLIENT_MAX_WINDOW_BITS = "client_max_window_bits"; + + private static final int RSV_BITMASK = 0b100; + private static final byte[] EOM_BYTES = new byte[] {0, 0, -1, -1}; + + public static final String NAME = "permessage-deflate"; + + private final boolean serverContextTakeover; + private final int serverMaxWindowBits; + private final boolean clientContextTakeover; + private final int clientMaxWindowBits; + private final boolean isServer; + private final Inflater inflater = new Inflater(true); + private final ByteBuffer readBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Deflater deflater = new Deflater(Deflater.DEFAULT_COMPRESSION, true); + private final byte[] EOM_BUFFER = new byte[EOM_BYTES.length + 1]; + + private volatile Transformation next; + private volatile boolean skipDecompression = false; + private volatile ByteBuffer writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private volatile boolean firstCompressedFrameWritten = false; + // Flag to track if a message is completely empty + private volatile boolean emptyMessage = true; + + static PerMessageDeflate negotiate(List<List<Parameter>> preferences, boolean isServer) { + // Accept the first preference that the endpoint is able to support + for (List<Parameter> preference : preferences) { + boolean ok = true; + boolean serverContextTakeover = true; + int serverMaxWindowBits = -1; + boolean clientContextTakeover = true; + int clientMaxWindowBits = -1; + + for (Parameter param : preference) { + if (SERVER_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (serverContextTakeover) { + serverContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_NO_CONTEXT_TAKEOVER )); + } + } else if (CLIENT_NO_CONTEXT_TAKEOVER.equals(param.getName())) { + if (clientContextTakeover) { + clientContextTakeover = false; + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_NO_CONTEXT_TAKEOVER )); + } + } else if (SERVER_MAX_WINDOW_BITS.equals(param.getName())) { + if (serverMaxWindowBits == -1) { + serverMaxWindowBits = Integer.parseInt(param.getValue()); + if (serverMaxWindowBits < 8 || serverMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + SERVER_MAX_WINDOW_BITS, + Integer.valueOf(serverMaxWindowBits))); + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (isServer && serverMaxWindowBits != 15) { + ok = false; + break; + // Note server window size is not an issue for the + // client since the client will assume 15 and if the + // server uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + SERVER_MAX_WINDOW_BITS )); + } + } else if (CLIENT_MAX_WINDOW_BITS.equals(param.getName())) { + if (clientMaxWindowBits == -1) { + if (param.getValue() == null) { + // Hint to server that the client supports this + // option. Java SE API (as of Java 8) does not + // expose the API to control the Window size. It is + // effectively hard-coded to 15 + clientMaxWindowBits = 15; + } else { + clientMaxWindowBits = Integer.parseInt(param.getValue()); + if (clientMaxWindowBits < 8 || clientMaxWindowBits > 15) { + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.invalidWindowSize", + CLIENT_MAX_WINDOW_BITS, + Integer.valueOf(clientMaxWindowBits))); + } + } + // Java SE API (as of Java 8) does not expose the API to + // control the Window size. It is effectively hard-coded + // to 15 + if (!isServer && clientMaxWindowBits != 15) { + ok = false; + break; + // Note client window size is not an issue for the + // server since the server will assume 15 and if the + // client uses a smaller window everything will + // still work + } + } else { + // Duplicate definition + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.duplicateParameter", + CLIENT_MAX_WINDOW_BITS )); + } + } else { + // Unknown parameter + throw new IllegalArgumentException(sm.getString( + "perMessageDeflate.unknownParameter", param.getName())); + } + } + if (ok) { + return new PerMessageDeflate(serverContextTakeover, serverMaxWindowBits, + clientContextTakeover, clientMaxWindowBits, isServer); + } + } + // Failed to negotiate agreeable terms + return null; + } + + + private PerMessageDeflate(boolean serverContextTakeover, int serverMaxWindowBits, + boolean clientContextTakeover, int clientMaxWindowBits, boolean isServer) { + this.serverContextTakeover = serverContextTakeover; + this.serverMaxWindowBits = serverMaxWindowBits; + this.clientContextTakeover = clientContextTakeover; + this.clientMaxWindowBits = clientMaxWindowBits; + this.isServer = isServer; + } + + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) + throws IOException { + // Control frames are never compressed and may appear in the middle of + // a WebSocket method. Pass them straight through. + if (Util.isControl(opCode)) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + if (!Util.isContinuation(opCode)) { + // First frame in new message + skipDecompression = (rsv & RSV_BITMASK) == 0; + } + + // Pass uncompressed frames straight through. + if (skipDecompression) { + return next.getMoreData(opCode, fin, rsv, dest); + } + + int written; + boolean usedEomBytes = false; + + while (dest.remaining() > 0) { + // Space available in destination. Try and fill it. + try { + written = inflater.inflate( + dest.array(), dest.arrayOffset() + dest.position(), dest.remaining()); + } catch (DataFormatException e) { + throw new IOException(sm.getString("perMessageDeflate.deflateFailed"), e); + } + dest.position(dest.position() + written); + + if (inflater.needsInput() && !usedEomBytes ) { + if (dest.hasRemaining()) { + readBuffer.clear(); + TransformationResult nextResult = + next.getMoreData(opCode, fin, (rsv ^ RSV_BITMASK), readBuffer); + inflater.setInput( + readBuffer.array(), readBuffer.arrayOffset(), readBuffer.position()); + if (TransformationResult.UNDERFLOW.equals(nextResult)) { + return nextResult; + } else if (TransformationResult.END_OF_FRAME.equals(nextResult) && + readBuffer.position() == 0) { + if (fin) { + inflater.setInput(EOM_BYTES); + usedEomBytes = true; + } else { + return TransformationResult.END_OF_FRAME; + } + } + } + } else if (written == 0) { + if (fin && (isServer && !clientContextTakeover || + !isServer && !serverContextTakeover)) { + inflater.reset(); + } + return TransformationResult.END_OF_FRAME; + } + } + + return TransformationResult.OVERFLOW; + } + + + @Override + public boolean validateRsv(int rsv, byte opCode) { + if (Util.isControl(opCode)) { + if ((rsv & RSV_BITMASK) != 0) { + return false; + } else { + if (next == null) { + return true; + } else { + return next.validateRsv(rsv, opCode); + } + } + } else { + int rsvNext = rsv; + if ((rsv & RSV_BITMASK) != 0) { + rsvNext = rsv ^ RSV_BITMASK; + } + if (next == null) { + return true; + } else { + return next.validateRsv(rsvNext, opCode); + } + } + } + + + @Override + public Extension getExtensionResponse() { + Extension result = new WsExtension(NAME); + + List<Extension.Parameter> params = result.getParameters(); + + if (!serverContextTakeover) { + params.add(new WsExtensionParameter(SERVER_NO_CONTEXT_TAKEOVER, null)); + } + if (serverMaxWindowBits != -1) { + params.add(new WsExtensionParameter(SERVER_MAX_WINDOW_BITS, + Integer.toString(serverMaxWindowBits))); + } + if (!clientContextTakeover) { + params.add(new WsExtensionParameter(CLIENT_NO_CONTEXT_TAKEOVER, null)); + } + if (clientMaxWindowBits != -1) { + params.add(new WsExtensionParameter(CLIENT_MAX_WINDOW_BITS, + Integer.toString(clientMaxWindowBits))); + } + + return result; + } + + + @Override + public void setNext(Transformation t) { + if (next == null) { + this.next = t; + } else { + next.setNext(t); + } + } + + + @Override + public boolean validateRsvBits(int i) { + if ((i & RSV_BITMASK) != 0) { + return false; + } + if (next == null) { + return true; + } else { + return next.validateRsvBits(i | RSV_BITMASK); + } + } + + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> uncompressedParts) { + List<MessagePart> allCompressedParts = new ArrayList<>(); + + for (MessagePart uncompressedPart : uncompressedParts) { + byte opCode = uncompressedPart.getOpCode(); + boolean emptyPart = uncompressedPart.getPayload().limit() == 0; + emptyMessage = emptyMessage && emptyPart; + if (Util.isControl(opCode)) { + // Control messages can appear in the middle of other messages + // and must not be compressed. Pass it straight through + allCompressedParts.add(uncompressedPart); + } else if (emptyMessage && uncompressedPart.isFin()) { + // Zero length messages can't be compressed so pass the + // final (empty) part straight through. + allCompressedParts.add(uncompressedPart); + } else { + List<MessagePart> compressedParts = new ArrayList<>(); + ByteBuffer uncompressedPayload = uncompressedPart.getPayload(); + SendHandler uncompressedIntermediateHandler = + uncompressedPart.getIntermediateHandler(); + + deflater.setInput(uncompressedPayload.array(), + uncompressedPayload.arrayOffset() + uncompressedPayload.position(), + uncompressedPayload.remaining()); + + int flush = (uncompressedPart.isFin() ? Deflater.SYNC_FLUSH : Deflater.NO_FLUSH); + boolean deflateRequired = true; + + while (deflateRequired) { + ByteBuffer compressedPayload = writeBuffer; + + int written = deflater.deflate(compressedPayload.array(), + compressedPayload.arrayOffset() + compressedPayload.position(), + compressedPayload.remaining(), flush); + compressedPayload.position(compressedPayload.position() + written); + + if (!uncompressedPart.isFin() && compressedPayload.hasRemaining() && deflater.needsInput()) { + // This message part has been fully processed by the + // deflater. Fire the send handler for this message part + // and move on to the next message part. + break; + } + + // If this point is reached, a new compressed message part + // will be created... + MessagePart compressedPart; + + // .. and a new writeBuffer will be required. + writeBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + + // Flip the compressed payload ready for writing + compressedPayload.flip(); + + boolean fin = uncompressedPart.isFin(); + boolean full = compressedPayload.limit() == compressedPayload.capacity(); + boolean needsInput = deflater.needsInput(); + long blockingWriteTimeoutExpiry = uncompressedPart.getBlockingWriteTimeoutExpiry(); + + if (fin && !full && needsInput) { + // End of compressed message. Drop EOM bytes and output. + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length); + compressedPart = new MessagePart(true, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else if (full && !needsInput) { + // Write buffer full and input message not fully read. + // Output and start new compressed part. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + } else if (!fin && full && needsInput) { + // Write buffer full and input message not fully read. + // Output and get more data. + compressedPart = new MessagePart(false, getRsv(uncompressedPart), + opCode, compressedPayload, uncompressedIntermediateHandler, + uncompressedIntermediateHandler, blockingWriteTimeoutExpiry); + deflateRequired = false; + } else if (fin && full && needsInput) { + // Write buffer full. Input fully read. Deflater may be + // in one of four states: + // - output complete (just happened to align with end of + // buffer + // - in middle of EOM bytes + // - about to write EOM bytes + // - more data to write + int eomBufferWritten = deflater.deflate(EOM_BUFFER, 0, EOM_BUFFER.length, Deflater.SYNC_FLUSH); + if (eomBufferWritten < EOM_BUFFER.length) { + // EOM has just been completed + compressedPayload.limit(compressedPayload.limit() - EOM_BYTES.length + eomBufferWritten); + compressedPart = new MessagePart(true, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + deflateRequired = false; + startNewMessage(); + } else { + // More data to write + // Copy bytes to new write buffer + writeBuffer.put(EOM_BUFFER, 0, eomBufferWritten); + compressedPart = new MessagePart(false, + getRsv(uncompressedPart), opCode, compressedPayload, + uncompressedIntermediateHandler, uncompressedIntermediateHandler, + blockingWriteTimeoutExpiry); + } + } else { + throw new IllegalStateException("Should never happen"); + } + + // Add the newly created compressed part to the set of parts + // to pass on to the next transformation. + compressedParts.add(compressedPart); + } + + SendHandler uncompressedEndHandler = uncompressedPart.getEndHandler(); + int size = compressedParts.size(); + if (size > 0) { + compressedParts.get(size - 1).setEndHandler(uncompressedEndHandler); + } + + allCompressedParts.addAll(compressedParts); + } + } + + if (next == null) { + return allCompressedParts; + } else { + return next.sendMessagePart(allCompressedParts); + } + } + + + private void startNewMessage() { + firstCompressedFrameWritten = false; + emptyMessage = true; + if (isServer && !serverContextTakeover || !isServer && !clientContextTakeover) { + deflater.reset(); + } + } + + + private int getRsv(MessagePart uncompressedMessagePart) { + int result = uncompressedMessagePart.getRsv(); + if (!firstCompressedFrameWritten) { + result += RSV_BITMASK; + firstCompressedFrameWritten = true; + } + return result; + } + + + @Override + public void close() { + // There will always be a next transformation + next.close(); + inflater.end(); + deflater.end(); + } +} diff --git a/src/java/nginx/unit/websocket/ReadBufferOverflowException.java b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java new file mode 100644 index 00000000..9ce7ac27 --- /dev/null +++ b/src/java/nginx/unit/websocket/ReadBufferOverflowException.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +public class ReadBufferOverflowException extends IOException { + + private static final long serialVersionUID = 1L; + + private final int minBufferSize; + + public ReadBufferOverflowException(int minBufferSize) { + this.minBufferSize = minBufferSize; + } + + public int getMinBufferSize() { + return minBufferSize; + } +} diff --git a/src/java/nginx/unit/websocket/Transformation.java b/src/java/nginx/unit/websocket/Transformation.java new file mode 100644 index 00000000..45474c7d --- /dev/null +++ b/src/java/nginx/unit/websocket/Transformation.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.List; + +import javax.websocket.Extension; + +/** + * The internal representation of the transformation that a WebSocket extension + * performs on a message. + */ +public interface Transformation { + + /** + * Sets the next transformation in the pipeline. + * @param t The next transformation + */ + void setNext(Transformation t); + + /** + * Validate that the RSV bit(s) required by this transformation are not + * being used by another extension. The implementation is expected to set + * any bits it requires before passing the set of in-use bits to the next + * transformation. + * + * @param i The RSV bits marked as in use so far as an int in the + * range zero to seven with RSV1 as the MSB and RSV3 as the + * LSB + * + * @return <code>true</code> if the combination of RSV bits used by the + * transformations in the pipeline do not conflict otherwise + * <code>false</code> + */ + boolean validateRsvBits(int i); + + /** + * Obtain the extension that describes the information to be returned to the + * client. + * + * @return The extension information that describes the parameters that have + * been agreed for this transformation + */ + Extension getExtensionResponse(); + + /** + * Obtain more input data. + * + * @param opCode The opcode for the frame currently being processed + * @param fin Is this the final frame in this WebSocket message? + * @param rsv The reserved bits for the frame currently being + * processed + * @param dest The buffer in which the data is to be written + * + * @return The result of trying to read more data from the transform + * + * @throws IOException If an I/O error occurs while reading data from the + * transform + */ + TransformationResult getMoreData(byte opCode, boolean fin, int rsv, ByteBuffer dest) throws IOException; + + /** + * Validates the RSV and opcode combination (assumed to have been extracted + * from a WebSocket Frame) for this extension. The implementation is + * expected to unset any RSV bits it has validated before passing the + * remaining RSV bits to the next transformation in the pipeline. + * + * @param rsv The RSV bits received as an int in the range zero to + * seven with RSV1 as the MSB and RSV3 as the LSB + * @param opCode The opCode received + * + * @return <code>true</code> if the RSV is valid otherwise + * <code>false</code> + */ + boolean validateRsv(int rsv, byte opCode); + + /** + * Takes the provided list of messages, transforms them, passes the + * transformed list on to the next transformation (if any) and then returns + * the resulting list of message parts after all of the transformations have + * been applied. + * + * @param messageParts The list of messages to be transformed + * + * @return The list of messages after this any any subsequent + * transformations have been applied. The size of the returned list + * may be bigger or smaller than the size of the input list + */ + List<MessagePart> sendMessagePart(List<MessagePart> messageParts); + + /** + * Clean-up any resources that were used by the transformation. + */ + void close(); +} diff --git a/src/java/nginx/unit/websocket/TransformationFactory.java b/src/java/nginx/unit/websocket/TransformationFactory.java new file mode 100644 index 00000000..fac04555 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationFactory.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.List; + +import javax.websocket.Extension; + +import org.apache.tomcat.util.res.StringManager; + +public class TransformationFactory { + + private static final StringManager sm = StringManager.getManager(TransformationFactory.class); + + private static final TransformationFactory factory = new TransformationFactory(); + + private TransformationFactory() { + // Hide default constructor + } + + public static TransformationFactory getInstance() { + return factory; + } + + public Transformation create(String name, List<List<Extension.Parameter>> preferences, + boolean isServer) { + if (PerMessageDeflate.NAME.equals(name)) { + return PerMessageDeflate.negotiate(preferences, isServer); + } + if (Constants.ALLOW_UNSUPPORTED_EXTENSIONS) { + return null; + } else { + throw new IllegalArgumentException( + sm.getString("transformerFactory.unsupportedExtension", name)); + } + } +} diff --git a/src/java/nginx/unit/websocket/TransformationResult.java b/src/java/nginx/unit/websocket/TransformationResult.java new file mode 100644 index 00000000..0de35e55 --- /dev/null +++ b/src/java/nginx/unit/websocket/TransformationResult.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +public enum TransformationResult { + /** + * The end of the available data was reached before the WebSocket frame was + * completely read. + */ + UNDERFLOW, + + /** + * The provided destination buffer was filled before all of the available + * data from the WebSocket frame could be processed. + */ + OVERFLOW, + + /** + * The end of the WebSocket frame was reached and all the data from that + * frame processed into the provided destination buffer. + */ + END_OF_FRAME +} diff --git a/src/java/nginx/unit/websocket/Util.java b/src/java/nginx/unit/websocket/Util.java new file mode 100644 index 00000000..6acf3ade --- /dev/null +++ b/src/java/nginx/unit/websocket/Util.java @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText; + +/** + * Utility class for internal use only within the + * {@link nginx.unit.websocket} package. + */ +public class Util { + + private static final StringManager sm = StringManager.getManager(Util.class); + private static final Queue<SecureRandom> randoms = + new ConcurrentLinkedQueue<>(); + + private Util() { + // Hide default constructor + } + + + static boolean isControl(byte opCode) { + return (opCode & 0x08) != 0; + } + + + static boolean isText(byte opCode) { + return opCode == Constants.OPCODE_TEXT; + } + + + static boolean isContinuation(byte opCode) { + return opCode == Constants.OPCODE_CONTINUATION; + } + + + static CloseCode getCloseCode(int code) { + if (code > 2999 && code < 5000) { + return CloseCodes.getCloseCode(code); + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + // Should not be used in a close frame + // return CloseCodes.RESERVED; + return CloseCodes.PROTOCOL_ERROR; + case 1005: + // Should not be used in a close frame + // return CloseCodes.NO_STATUS_CODE; + return CloseCodes.PROTOCOL_ERROR; + case 1006: + // Should not be used in a close frame + // return CloseCodes.CLOSED_ABNORMALLY; + return CloseCodes.PROTOCOL_ERROR; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + // Not in RFC6455 + // return CloseCodes.SERVICE_RESTART; + return CloseCodes.PROTOCOL_ERROR; + case 1013: + // Not in RFC6455 + // return CloseCodes.TRY_AGAIN_LATER; + return CloseCodes.PROTOCOL_ERROR; + case 1015: + // Should not be used in a close frame + // return CloseCodes.TLS_HANDSHAKE_FAILURE; + return CloseCodes.PROTOCOL_ERROR; + default: + return CloseCodes.PROTOCOL_ERROR; + } + } + + + static byte[] generateMask() { + // SecureRandom is not thread-safe so need to make sure only one thread + // uses it at a time. In theory, the pool could grow to the same size + // as the number of request processing threads. In reality it will be + // a lot smaller. + + // Get a SecureRandom from the pool + SecureRandom sr = randoms.poll(); + + // If one isn't available, generate a new one + if (sr == null) { + try { + sr = SecureRandom.getInstance("SHA1PRNG"); + } catch (NoSuchAlgorithmException e) { + // Fall back to platform default + sr = new SecureRandom(); + } + } + + // Generate the mask + byte[] result = new byte[4]; + sr.nextBytes(result); + + // Put the SecureRandom back in the poll + randoms.add(sr); + + return result; + } + + + static Class<?> getMessageType(MessageHandler listener) { + return Util.getGenericType(MessageHandler.class, + listener.getClass()).getClazz(); + } + + + private static Class<?> getDecoderType(Class<? extends Decoder> decoder) { + return Util.getGenericType(Decoder.class, decoder).getClazz(); + } + + + static Class<?> getEncoderType(Class<? extends Encoder> encoder) { + return Util.getGenericType(Encoder.class, encoder).getClazz(); + } + + + private static <T> TypeResult getGenericType(Class<T> type, + Class<? extends T> clazz) { + + // Look to see if this class implements the interface of interest + + // Get all the interfaces + Type[] interfaces = clazz.getGenericInterfaces(); + for (Type iface : interfaces) { + // Only need to check interfaces that use generics + if (iface instanceof ParameterizedType) { + ParameterizedType pi = (ParameterizedType) iface; + // Look for the interface of interest + if (pi.getRawType() instanceof Class) { + if (type.isAssignableFrom((Class<?>) pi.getRawType())) { + return getTypeParameter( + clazz, pi.getActualTypeArguments()[0]); + } + } + } + } + + // Interface not found on this class. Look at the superclass. + @SuppressWarnings("unchecked") + Class<? extends T> superClazz = + (Class<? extends T>) clazz.getSuperclass(); + if (superClazz == null) { + // Finished looking up the class hierarchy without finding anything + return null; + } + + TypeResult superClassTypeResult = getGenericType(type, superClazz); + int dimension = superClassTypeResult.getDimension(); + if (superClassTypeResult.getIndex() == -1 && dimension == 0) { + // Superclass implements interface and defines explicit type for + // the interface of interest + return superClassTypeResult; + } + + if (superClassTypeResult.getIndex() > -1) { + // Superclass implements interface and defines unknown type for + // the interface of interest + // Map that unknown type to the generic types defined in this class + ParameterizedType superClassType = + (ParameterizedType) clazz.getGenericSuperclass(); + TypeResult result = getTypeParameter(clazz, + superClassType.getActualTypeArguments()[ + superClassTypeResult.getIndex()]); + result.incrementDimension(superClassTypeResult.getDimension()); + if (result.getClazz() != null && result.getDimension() > 0) { + superClassTypeResult = result; + } else { + return result; + } + } + + if (superClassTypeResult.getDimension() > 0) { + StringBuilder className = new StringBuilder(); + for (int i = 0; i < dimension; i++) { + className.append('['); + } + className.append('L'); + className.append(superClassTypeResult.getClazz().getCanonicalName()); + className.append(';'); + + Class<?> arrayClazz; + try { + arrayClazz = Class.forName(className.toString()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + + return new TypeResult(arrayClazz, -1, 0); + } + + // Error will be logged further up the call stack + return null; + } + + + /* + * For a generic parameter, return either the Class used or if the type + * is unknown, the index for the type in definition of the class + */ + private static TypeResult getTypeParameter(Class<?> clazz, Type argType) { + if (argType instanceof Class<?>) { + return new TypeResult((Class<?>) argType, -1, 0); + } else if (argType instanceof ParameterizedType) { + return new TypeResult((Class<?>)((ParameterizedType) argType).getRawType(), -1, 0); + } else if (argType instanceof GenericArrayType) { + Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType(); + TypeResult result = getTypeParameter(clazz, arrayElementType); + result.incrementDimension(1); + return result; + } else { + TypeVariable<?>[] tvs = clazz.getTypeParameters(); + for (int i = 0; i < tvs.length; i++) { + if (tvs[i].equals(argType)) { + return new TypeResult(null, i, 0); + } + } + return null; + } + } + + + public static boolean isPrimitive(Class<?> clazz) { + if (clazz.isPrimitive()) { + return true; + } else if(clazz.equals(Boolean.class) || + clazz.equals(Byte.class) || + clazz.equals(Character.class) || + clazz.equals(Double.class) || + clazz.equals(Float.class) || + clazz.equals(Integer.class) || + clazz.equals(Long.class) || + clazz.equals(Short.class)) { + return true; + } + return false; + } + + + public static Object coerceToType(Class<?> type, String value) { + if (type.equals(String.class)) { + return value; + } else if (type.equals(boolean.class) || type.equals(Boolean.class)) { + return Boolean.valueOf(value); + } else if (type.equals(byte.class) || type.equals(Byte.class)) { + return Byte.valueOf(value); + } else if (type.equals(char.class) || type.equals(Character.class)) { + return Character.valueOf(value.charAt(0)); + } else if (type.equals(double.class) || type.equals(Double.class)) { + return Double.valueOf(value); + } else if (type.equals(float.class) || type.equals(Float.class)) { + return Float.valueOf(value); + } else if (type.equals(int.class) || type.equals(Integer.class)) { + return Integer.valueOf(value); + } else if (type.equals(long.class) || type.equals(Long.class)) { + return Long.valueOf(value); + } else if (type.equals(short.class) || type.equals(Short.class)) { + return Short.valueOf(value); + } else { + throw new IllegalArgumentException(sm.getString( + "util.invalidType", value, type.getName())); + } + } + + + public static List<DecoderEntry> getDecoders( + List<Class<? extends Decoder>> decoderClazzes) + throws DeploymentException{ + + List<DecoderEntry> result = new ArrayList<>(); + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Decoder instance; + try { + instance = decoderClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("pojoMethodMapping.invalidDecoder", + decoderClazz.getName()), e); + } + DecoderEntry entry = new DecoderEntry( + Util.getDecoderType(decoderClazz), decoderClazz); + result.add(entry); + } + } + + return result; + } + + + static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, + MessageHandler listener, EndpointConfig endpointConfig, + Session session) { + + // Will never be more than 2 types + Set<MessageHandlerResult> results = new HashSet<>(2); + + // Simple cases - handlers already accepts one of the types expected by + // the frame handling code + if (String.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.TEXT); + results.add(result); + } else if (ByteBuffer.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.BINARY); + results.add(result); + } else if (PongMessage.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.PONG); + results.add(result); + // Handler needs wrapping and optional decoder to convert it to one of + // the types expected by the frame handling code + } else if (byte[].class.isAssignableFrom(target)) { + boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass()); + MessageHandlerResult result = new MessageHandlerResult( + whole ? new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, false, -1) : + new PojoMessageHandlerPartialBinary(listener, + getOnMessagePartialMethod(listener), session, + new Object[2], 0, true, 1, -1, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (InputStream.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, true, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (Reader.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, false), + new Object[1], 0, true, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } else { + // Handler needs wrapping and requires decoder to convert it to one + // of the types expected by the frame handling code + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + Method m = getOnMessageMethod(listener); + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, m, session, + endpointConfig, + decoderMatch.getBinaryDecoders(), new Object[1], + 0, false, -1, false, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, m, session, + endpointConfig, + decoderMatch.getTextDecoders(), new Object[1], + 0, false, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } + } + + if (results.size() == 0) { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandler", listener, target)); + } + + return results; + } + + private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, + EndpointConfig endpointConfig, boolean binary) { + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + if (binary) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + return decoderMatch.getBinaryDecoders(); + } + } else if (decoderMatch.getTextDecoders().size() > 0) { + return decoderMatch.getTextDecoders(); + } + return null; + } + + private static DecoderMatch matchDecoders(Class<?> target, + EndpointConfig endpointConfig) { + DecoderMatch decoderMatch; + try { + List<Class<? extends Decoder>> decoders = + endpointConfig.getDecoders(); + List<DecoderEntry> decoderEntries = getDecoders(decoders); + decoderMatch = new DecoderMatch(target, decoderEntries); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); + } + return decoderMatch; + } + + public static void parseExtensionHeader(List<Extension> extensions, + String header) { + // The relevant ABNF for the Sec-WebSocket-Extensions is as follows: + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ; When using the quoted-string syntax variant, the value + // ; after quoted-string unescaping MUST conform to the + // ; 'token' ABNF. + // + // The limiting of parameter values to tokens or "quoted tokens" makes + // the parsing of the header significantly simpler and allows a number + // of short-cuts to be taken. + + // Step one, split the header into individual extensions using ',' as a + // separator + String unparsedExtensions[] = header.split(","); + for (String unparsedExtension : unparsedExtensions) { + // Step two, split the extension into the registered name and + // parameter/value pairs using ';' as a separator + String unparsedParameters[] = unparsedExtension.split(";"); + WsExtension extension = new WsExtension(unparsedParameters[0].trim()); + + for (int i = 1; i < unparsedParameters.length; i++) { + int equalsPos = unparsedParameters[i].indexOf('='); + String name; + String value; + if (equalsPos == -1) { + name = unparsedParameters[i].trim(); + value = null; + } else { + name = unparsedParameters[i].substring(0, equalsPos).trim(); + value = unparsedParameters[i].substring(equalsPos + 1).trim(); + int len = value.length(); + if (len > 1) { + if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') { + value = value.substring(1, value.length() - 1); + } + } + } + // Make sure value doesn't contain any of the delimiters since + // that would indicate something went wrong + if (containsDelims(name) || containsDelims(value)) { + throw new IllegalArgumentException(sm.getString( + "util.notToken", name, value)); + } + if (value != null && + (value.indexOf(',') > -1 || value.indexOf(';') > -1 || + value.indexOf('\"') > -1 || value.indexOf('=') > -1)) { + throw new IllegalArgumentException(sm.getString("", value)); + } + extension.addParameter(new WsExtensionParameter(name, value)); + } + extensions.add(extension); + } + } + + + private static boolean containsDelims(String input) { + if (input == null || input.length() == 0) { + return false; + } + for (char c : input.toCharArray()) { + switch (c) { + case ',': + case ';': + case '\"': + case '=': + return true; + default: + // NO_OP + } + + } + return false; + } + + private static Method getOnMessageMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + private static Method getOnMessagePartialMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + + public static class DecoderMatch { + + private final List<Class<? extends Decoder>> textDecoders = + new ArrayList<>(); + private final List<Class<? extends Decoder>> binaryDecoders = + new ArrayList<>(); + private final Class<?> target; + + public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) { + this.target = target; + for (DecoderEntry decoderEntry : decoderEntries) { + if (decoderEntry.getClazz().isAssignableFrom(target)) { + if (Binary.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (BinaryStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else if (Text.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (TextStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else { + throw new IllegalArgumentException( + sm.getString("util.unknownDecoderType")); + } + } + } + } + + + public List<Class<? extends Decoder>> getTextDecoders() { + return textDecoders; + } + + + public List<Class<? extends Decoder>> getBinaryDecoders() { + return binaryDecoders; + } + + + public Class<?> getTarget() { + return target; + } + + + public boolean hasMatches() { + return (textDecoders.size() > 0) || (binaryDecoders.size() > 0); + } + } + + + private static class TypeResult { + private final Class<?> clazz; + private final int index; + private int dimension; + + public TypeResult(Class<?> clazz, int index, int dimension) { + this.clazz= clazz; + this.index = index; + this.dimension = dimension; + } + + public Class<?> getClazz() { + return clazz; + } + + public int getIndex() { + return index; + } + + public int getDimension() { + return dimension; + } + + public void incrementDimension(int inc) { + dimension += inc; + } + } +} diff --git a/src/java/nginx/unit/websocket/WrappedMessageHandler.java b/src/java/nginx/unit/websocket/WrappedMessageHandler.java new file mode 100644 index 00000000..2557a73e --- /dev/null +++ b/src/java/nginx/unit/websocket/WrappedMessageHandler.java @@ -0,0 +1,25 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.MessageHandler; + +public interface WrappedMessageHandler { + long getMaxMessageSize(); + + MessageHandler getWrappedHandler(); +} diff --git a/src/java/nginx/unit/websocket/WsContainerProvider.java b/src/java/nginx/unit/websocket/WsContainerProvider.java new file mode 100644 index 00000000..f8a404a1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsContainerProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.ContainerProvider; +import javax.websocket.WebSocketContainer; + +public class WsContainerProvider extends ContainerProvider { + + @Override + protected WebSocketContainer getContainer() { + return new WsWebSocketContainer(); + } +} diff --git a/src/java/nginx/unit/websocket/WsExtension.java b/src/java/nginx/unit/websocket/WsExtension.java new file mode 100644 index 00000000..3846feb1 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtension.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.Extension; + +public class WsExtension implements Extension { + + private final String name; + private final List<Parameter> parameters = new ArrayList<>(); + + WsExtension(String name) { + this.name = name; + } + + void addParameter(Parameter parameter) { + parameters.add(parameter); + } + + @Override + public String getName() { + return name; + } + + @Override + public List<Parameter> getParameters() { + return parameters; + } +} diff --git a/src/java/nginx/unit/websocket/WsExtensionParameter.java b/src/java/nginx/unit/websocket/WsExtensionParameter.java new file mode 100644 index 00000000..9b82f1c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsExtensionParameter.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import javax.websocket.Extension.Parameter; + +public class WsExtensionParameter implements Parameter { + + private final String name; + private final String value; + + WsExtensionParameter(String name, String value) { + this.name = name; + this.value = value; + } + + @Override + public String getName() { + return name; + } + + @Override + public String getValue() { + return value; + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameBase.java b/src/java/nginx/unit/websocket/WsFrameBase.java new file mode 100644 index 00000000..06d20bf4 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameBase.java @@ -0,0 +1,1010 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.util.List; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; + +import org.apache.juli.logging.Log; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +/** + * Takes the ServletInputStream, processes the WebSocket frames it contains and + * extracts the messages. WebSocket Pings received will be responded to + * automatically without any action required by the application. + */ +public abstract class WsFrameBase { + + private static final StringManager sm = StringManager.getManager(WsFrameBase.class); + + // Connection level attributes + protected final WsSession wsSession; + protected final ByteBuffer inputBuffer; + private final Transformation transformation; + + // Attributes for control messages + // Control messages can appear in the middle of other messages so need + // separate attributes + private final ByteBuffer controlBufferBinary = ByteBuffer.allocate(125); + private final CharBuffer controlBufferText = CharBuffer.allocate(125); + + // Attributes of the current message + private final CharsetDecoder utf8DecoderControl = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + private boolean continuationExpected = false; + private boolean textMessage = false; + private ByteBuffer messageBufferBinary; + private CharBuffer messageBufferText; + // Cache the message handler in force when the message starts so it is used + // consistently for the entire message + private MessageHandler binaryMsgHandler = null; + private MessageHandler textMsgHandler = null; + + // Attributes of the current frame + private boolean fin = false; + private int rsv = 0; + private byte opCode = 0; + private final byte[] mask = new byte[4]; + private int maskIndex = 0; + private long payloadLength = 0; + private volatile long payloadWritten = 0; + + // Attributes tracking state + private volatile State state = State.NEW_FRAME; + private volatile boolean open = true; + + private static final AtomicReferenceFieldUpdater<WsFrameBase, ReadState> READ_STATE_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(WsFrameBase.class, ReadState.class, "readState"); + private volatile ReadState readState = ReadState.WAITING; + + public WsFrameBase(WsSession wsSession, Transformation transformation) { + inputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + inputBuffer.position(0).limit(0); + messageBufferBinary = ByteBuffer.allocate(wsSession.getMaxBinaryMessageBufferSize()); + messageBufferText = CharBuffer.allocate(wsSession.getMaxTextMessageBufferSize()); + this.wsSession = wsSession; + Transformation finalTransformation; + if (isMasked()) { + finalTransformation = new UnmaskTransformation(); + } else { + finalTransformation = new NoopTransformation(); + } + if (transformation == null) { + this.transformation = finalTransformation; + } else { + transformation.setNext(finalTransformation); + this.transformation = transformation; + } + } + + + protected void processInputBuffer() throws IOException { + while (!isSuspended()) { + wsSession.updateLastActive(); + if (state == State.NEW_FRAME) { + if (!processInitialHeader()) { + break; + } + // If a close frame has been received, no further data should + // have seen + if (!open) { + throw new IOException(sm.getString("wsFrame.closed")); + } + } + if (state == State.PARTIAL_HEADER) { + if (!processRemainingHeader()) { + break; + } + } + if (state == State.DATA) { + if (!processData()) { + break; + } + } + } + } + + + /** + * @return <code>true</code> if sufficient data was present to process all + * of the initial header + */ + private boolean processInitialHeader() throws IOException { + // Need at least two bytes of data to do this + if (inputBuffer.remaining() < 2) { + return false; + } + int b = inputBuffer.get(); + fin = (b & 0x80) != 0; + rsv = (b & 0x70) >>> 4; + opCode = (byte) (b & 0x0F); + if (!transformation.validateRsv(rsv, opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.wrongRsv", Integer.valueOf(rsv), Integer.valueOf(opCode)))); + } + + if (Util.isControl(opCode)) { + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlFragmented"))); + } + if (opCode != Constants.OPCODE_PING && + opCode != Constants.OPCODE_PONG && + opCode != Constants.OPCODE_CLOSE) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } else { + if (continuationExpected) { + if (!Util.isContinuation(opCode)) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.noContinuation"))); + } + } else { + try { + if (opCode == Constants.OPCODE_BINARY) { + // New binary message + textMessage = false; + int size = wsSession.getMaxBinaryMessageBufferSize(); + if (size != messageBufferBinary.capacity()) { + messageBufferBinary = ByteBuffer.allocate(size); + } + binaryMsgHandler = wsSession.getBinaryMessageHandler(); + textMsgHandler = null; + } else if (opCode == Constants.OPCODE_TEXT) { + // New text message + textMessage = true; + int size = wsSession.getMaxTextMessageBufferSize(); + if (size != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(size); + } + binaryMsgHandler = null; + textMsgHandler = wsSession.getTextMessageHandler(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + } catch (IllegalStateException ise) { + // Thrown if the session is already closed + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.sessionClosed"))); + } + } + continuationExpected = !fin; + } + b = inputBuffer.get(); + // Client data must be masked + if ((b & 0x80) == 0 && isMasked()) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.notMasked"))); + } + payloadLength = b & 0x7F; + state = State.PARTIAL_HEADER; + if (getLog().isDebugEnabled()) { + getLog().debug(sm.getString("wsFrame.partialHeaderComplete", Boolean.toString(fin), + Integer.toString(rsv), Integer.toString(opCode), Long.toString(payloadLength))); + } + return true; + } + + + protected abstract boolean isMasked(); + protected abstract Log getLog(); + + + /** + * @return <code>true</code> if sufficient data was present to complete the + * processing of the header + */ + private boolean processRemainingHeader() throws IOException { + // Ignore the 2 bytes already read. 4 for the mask + int headerLength; + if (isMasked()) { + headerLength = 4; + } else { + headerLength = 0; + } + // Add additional bytes depending on length + if (payloadLength == 126) { + headerLength += 2; + } else if (payloadLength == 127) { + headerLength += 8; + } + if (inputBuffer.remaining() < headerLength) { + return false; + } + // Calculate new payload length if necessary + if (payloadLength == 126) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 2); + inputBuffer.position(inputBuffer.position() + 2); + } else if (payloadLength == 127) { + payloadLength = byteArrayToLong(inputBuffer.array(), + inputBuffer.arrayOffset() + inputBuffer.position(), 8); + inputBuffer.position(inputBuffer.position() + 8); + } + if (Util.isControl(opCode)) { + if (payloadLength > 125) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlPayloadTooBig", Long.valueOf(payloadLength)))); + } + if (!fin) { + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.controlNoFin"))); + } + } + if (isMasked()) { + inputBuffer.get(mask, 0, 4); + } + state = State.DATA; + return true; + } + + + private boolean processData() throws IOException { + boolean result; + if (Util.isControl(opCode)) { + result = processDataControl(); + } else if (textMessage) { + if (textMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataText(); + } + } else { + if (binaryMsgHandler == null) { + result = swallowInput(); + } else { + result = processDataBinary(); + } + } + checkRoomPayload(); + return result; + } + + + private boolean processDataControl() throws IOException { + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, controlBufferBinary); + if (TransformationResult.UNDERFLOW.equals(tr)) { + return false; + } + // Control messages have fixed message size so + // TransformationResult.OVERFLOW is not possible here + + controlBufferBinary.flip(); + if (opCode == Constants.OPCODE_CLOSE) { + open = false; + String reason = null; + int code = CloseCodes.NORMAL_CLOSURE.getCode(); + if (controlBufferBinary.remaining() == 1) { + controlBufferBinary.clear(); + // Payload must be zero or 2+ bytes long + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.oneByteCloseCode"))); + } + if (controlBufferBinary.remaining() > 1) { + code = controlBufferBinary.getShort(); + if (controlBufferBinary.remaining() > 0) { + CoderResult cr = utf8DecoderControl.decode(controlBufferBinary, + controlBufferText, true); + if (cr.isError()) { + controlBufferBinary.clear(); + controlBufferText.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidUtf8Close"))); + } + // There will be no overflow as the output buffer is big + // enough. There will be no underflow as all the data is + // passed to the decoder in a single call. + controlBufferText.flip(); + reason = controlBufferText.toString(); + } + } + wsSession.onClose(new CloseReason(Util.getCloseCode(code), reason)); + } else if (opCode == Constants.OPCODE_PING) { + if (wsSession.isOpen()) { + wsSession.getBasicRemote().sendPong(controlBufferBinary); + } + } else if (opCode == Constants.OPCODE_PONG) { + MessageHandler.Whole<PongMessage> mhPong = wsSession.getPongMessageHandler(); + if (mhPong != null) { + try { + mhPong.onMessage(new WsPongMessage(controlBufferBinary)); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + controlBufferBinary.clear(); + } + } + } else { + // Should have caught this earlier but just in case... + controlBufferBinary.clear(); + throw new WsIOException(new CloseReason( + CloseCodes.PROTOCOL_ERROR, + sm.getString("wsFrame.invalidOpCode", Integer.valueOf(opCode)))); + } + controlBufferBinary.clear(); + newFrame(); + return true; + } + + + @SuppressWarnings("unchecked") + protected void sendMessageText(boolean last) throws WsIOException { + if (textMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) textMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && messageBufferText.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(messageBufferText.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + + try { + if (textMsgHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<String>) textMsgHandler) + .onMessage(messageBufferText.toString(), last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<String>) textMsgHandler) + .onMessage(messageBufferText.toString()); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + messageBufferText.clear(); + } + } + + + private boolean processDataText() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - we ran out of something + // Convert bytes to UTF-8 + messageBufferBinary.flip(); + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + false); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow()) { + // Compact what we have to create as much space as possible + messageBufferBinary.compact(); + + // Need more input + // What did we run out of? + if (TransformationResult.OVERFLOW.equals(tr)) { + // Ran out of message buffer - exit inner loop and + // refill + break; + } else { + // TransformationResult.UNDERFLOW + // Ran out of input data - get some more + return false; + } + } + } + // Read more input data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + messageBufferBinary.flip(); + boolean last = false; + // Frame is fully received + // Convert bytes to UTF-8 + while (true) { + CoderResult cr = utf8DecoderMessage.decode(messageBufferBinary, messageBufferText, + last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + // End of frame and possible message as well. + + if (continuationExpected) { + // If partial messages are supported, send what we have + // managed to decode + if (usePartial()) { + messageBufferText.flip(); + sendMessageText(false); + messageBufferText.clear(); + } + messageBufferBinary.compact(); + newFrame(); + // Process next frame + return true; + } else { + // Make sure coder has flushed all output + last = true; + } + } else { + // End of message + messageBufferText.flip(); + sendMessageText(true); + + newMessage(); + return true; + } + } + } + + + private boolean processDataBinary() throws IOException { + // Copy the available data to the buffer + TransformationResult tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + while (!TransformationResult.END_OF_FRAME.equals(tr)) { + // Frame not complete - what did we run out of? + if (TransformationResult.UNDERFLOW.equals(tr)) { + // Ran out of input data - get some more + return false; + } + + // Ran out of message buffer - flush it + if (!usePartial()) { + CloseReason cr = new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.bufferTooSmall", + Integer.valueOf(messageBufferBinary.capacity()), + Long.valueOf(payloadLength))); + throw new WsIOException(cr); + } + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, false); + messageBufferBinary.clear(); + // Read more data + tr = transformation.getMoreData(opCode, fin, rsv, messageBufferBinary); + } + + // Frame is fully received + // Send the message if either: + // - partial messages are supported + // - the message is complete + if (usePartial() || !continuationExpected) { + messageBufferBinary.flip(); + ByteBuffer copy = ByteBuffer.allocate(messageBufferBinary.limit()); + copy.put(messageBufferBinary); + copy.flip(); + sendMessageBinary(copy, !continuationExpected); + messageBufferBinary.clear(); + } + + if (continuationExpected) { + // More data for this message expected, start a new frame + newFrame(); + } else { + // Message is complete, start a new message + newMessage(); + } + + return true; + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + wsSession.getLocal().onError(wsSession, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + @SuppressWarnings("unchecked") + protected void sendMessageBinary(ByteBuffer msg, boolean last) throws WsIOException { + if (binaryMsgHandler instanceof WrappedMessageHandler) { + long maxMessageSize = ((WrappedMessageHandler) binaryMsgHandler).getMaxMessageSize(); + if (maxMessageSize > -1 && msg.remaining() > maxMessageSize) { + throw new WsIOException(new CloseReason(CloseCodes.TOO_BIG, + sm.getString("wsFrame.messageTooBig", + Long.valueOf(msg.remaining()), + Long.valueOf(maxMessageSize)))); + } + } + try { + if (binaryMsgHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<ByteBuffer>) binaryMsgHandler).onMessage(msg, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<ByteBuffer>) binaryMsgHandler).onMessage(msg); + } + } catch (Throwable t) { + handleThrowableOnSend(t); + } + } + + + private void newMessage() { + messageBufferBinary.clear(); + messageBufferText.clear(); + utf8DecoderMessage.reset(); + continuationExpected = false; + newFrame(); + } + + + private void newFrame() { + if (inputBuffer.remaining() == 0) { + inputBuffer.position(0).limit(0); + } + + maskIndex = 0; + payloadWritten = 0; + state = State.NEW_FRAME; + + // These get reset in processInitialHeader() + // fin, rsv, opCode, payloadLength, mask + + checkRoomHeaders(); + } + + + private void checkRoomHeaders() { + // Is the start of the current frame too near the end of the input + // buffer? + if (inputBuffer.capacity() - inputBuffer.position() < 131) { + // Limit based on a control frame with a full payload + makeRoom(); + } + } + + + private void checkRoomPayload() { + if (inputBuffer.capacity() - inputBuffer.position() - payloadLength + payloadWritten < 0) { + makeRoom(); + } + } + + + private void makeRoom() { + inputBuffer.compact(); + inputBuffer.flip(); + } + + + private boolean usePartial() { + if (Util.isControl(opCode)) { + return false; + } else if (textMessage) { + return textMsgHandler instanceof MessageHandler.Partial; + } else { + // Must be binary + return binaryMsgHandler instanceof MessageHandler.Partial; + } + } + + + private boolean swallowInput() { + long toSkip = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + inputBuffer.position(inputBuffer.position() + (int) toSkip); + payloadWritten += toSkip; + if (payloadWritten == payloadLength) { + if (continuationExpected) { + newFrame(); + } else { + newMessage(); + } + return true; + } else { + return false; + } + } + + + protected static long byteArrayToLong(byte[] b, int start, int len) throws IOException { + if (len > 8) { + throw new IOException(sm.getString("wsFrame.byteToLongFail", Long.valueOf(len))); + } + int shift = 0; + long result = 0; + for (int i = start + len - 1; i >= start; i--) { + result = result + ((b[i] & 0xFF) << shift); + shift += 8; + } + return result; + } + + + protected boolean isOpen() { + return open; + } + + + protected Transformation getTransformation() { + return transformation; + } + + + private enum State { + NEW_FRAME, PARTIAL_HEADER, DATA + } + + + /** + * WAITING - not suspended + * Server case: waiting for a notification that data + * is ready to be read from the socket, the socket is + * registered to the poller + * Client case: data has been read from the socket and + * is waiting for data to be processed + * PROCESSING - not suspended + * Server case: reading from the socket and processing + * the data + * Client case: processing the data if such has + * already been read and more data will be read from + * the socket + * SUSPENDING_WAIT - suspended, a call to suspend() was made while in + * WAITING state. A call to resume() will do nothing + * and will transition to WAITING state + * SUSPENDING_PROCESS - suspended, a call to suspend() was made while in + * PROCESSING state. A call to resume() will do + * nothing and will transition to PROCESSING state + * SUSPENDED - suspended + * Server case: processing data finished + * (SUSPENDING_PROCESS) / a notification was received + * that data is ready to be read from the socket + * (SUSPENDING_WAIT), socket is not registered to the + * poller + * Client case: processing data finished + * (SUSPENDING_PROCESS) / data has been read from the + * socket and is available for processing + * (SUSPENDING_WAIT) + * A call to resume() will: + * Server case: register the socket to the poller + * Client case: resume data processing + * CLOSING - not suspended, a close will be send + * + * <pre> + * resume data to be resume + * no action processed no action + * |---------------| |---------------| |----------| + * | v | v v | + * | |----------WAITING --------PROCESSING----| | + * | | ^ processing | | + * | | | finished | | + * | | | | | + * | suspend | suspend | + * | | | | | + * | | resume | | + * | | register socket to poller (server) | | + * | | resume data processing (client) | | + * | | | | | + * | v | v | + * SUSPENDING_WAIT | SUSPENDING_PROCESS + * | | | + * | data available | processing finished | + * |------------- SUSPENDED ----------------------| + * </pre> + */ + protected enum ReadState { + WAITING (false), + PROCESSING (false), + SUSPENDING_WAIT (true), + SUSPENDING_PROCESS(true), + SUSPENDED (true), + CLOSING (false); + + private final boolean isSuspended; + + ReadState(boolean isSuspended) { + this.isSuspended = isSuspended; + } + + public boolean isSuspended() { + return isSuspended; + } + } + + public void suspend() { + while (true) { + switch (readState) { + case WAITING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.WAITING, + ReadState.SUSPENDING_WAIT)) { + continue; + } + return; + case PROCESSING: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.PROCESSING, + ReadState.SUSPENDING_PROCESS)) { + continue; + } + return; + case SUSPENDING_WAIT: + if (readState != ReadState.SUSPENDING_WAIT) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDING_PROCESS: + if (readState != ReadState.SUSPENDING_PROCESS) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.suspendRequested")); + } + } + return; + case SUSPENDED: + if (readState != ReadState.SUSPENDED) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadySuspended")); + } + } + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + public void resume() { + while (true) { + switch (readState) { + case WAITING: + if (readState != ReadState.WAITING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case PROCESSING: + if (readState != ReadState.PROCESSING) { + continue; + } else { + if (getLog().isWarnEnabled()) { + getLog().warn(sm.getString("wsFrame.alreadyResumed")); + } + } + return; + case SUSPENDING_WAIT: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_WAIT, + ReadState.WAITING)) { + continue; + } + return; + case SUSPENDING_PROCESS: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDING_PROCESS, + ReadState.PROCESSING)) { + continue; + } + return; + case SUSPENDED: + if (!READ_STATE_UPDATER.compareAndSet(this, ReadState.SUSPENDED, + ReadState.WAITING)) { + continue; + } + resumeProcessing(); + return; + case CLOSING: + return; + default: + throw new IllegalStateException(sm.getString("wsFrame.illegalReadState", state)); + } + } + } + + protected boolean isSuspended() { + return readState.isSuspended(); + } + + protected ReadState getReadState() { + return readState; + } + + protected void changeReadState(ReadState newState) { + READ_STATE_UPDATER.set(this, newState); + } + + protected boolean changeReadState(ReadState oldState, ReadState newState) { + return READ_STATE_UPDATER.compareAndSet(this, oldState, newState); + } + + /** + * This method will be invoked when the read operation is resumed. + * As the suspend of the read operation can be invoked at any time, when + * implementing this method one should consider that there might still be + * data remaining into the internal buffers that needs to be processed + * before reading again from the socket. + */ + protected abstract void resumeProcessing(); + + + private abstract class TerminalTransformation implements Transformation { + + @Override + public boolean validateRsvBits(int i) { + // Terminal transformations don't use RSV bits and there is no next + // transformation so always return true. + return true; + } + + @Override + public Extension getExtensionResponse() { + // Return null since terminal transformations are not extensions + return null; + } + + @Override + public void setNext(Transformation t) { + // NO-OP since this is the terminal transformation + } + + /** + * {@inheritDoc} + * <p> + * Anything other than a value of zero for rsv is invalid. + */ + @Override + public boolean validateRsv(int rsv, byte opCode) { + return rsv == 0; + } + + @Override + public void close() { + // NO-OP for the terminal transformations + } + } + + + /** + * For use by the client implementation that needs to obtain payload data + * without the need for unmasking. + */ + private final class NoopTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + long toWrite = Math.min(payloadLength - payloadWritten, inputBuffer.remaining()); + toWrite = Math.min(toWrite, dest.remaining()); + + int orgLimit = inputBuffer.limit(); + inputBuffer.limit(inputBuffer.position() + (int) toWrite); + dest.put(inputBuffer); + inputBuffer.limit(orgLimit); + payloadWritten += toWrite; + + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) { + // TODO Masking should move to this method + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } + + + /** + * For use by the server implementation that needs to obtain payload data + * and unmask it before any further processing. + */ + private final class UnmaskTransformation extends TerminalTransformation { + + @Override + public TransformationResult getMoreData(byte opCode, boolean fin, int rsv, + ByteBuffer dest) { + // opCode is ignored as the transformation is the same for all + // opCodes + // rsv is ignored as it known to be zero at this point + while (payloadWritten < payloadLength && inputBuffer.remaining() > 0 && + dest.hasRemaining()) { + byte b = (byte) ((inputBuffer.get() ^ mask[maskIndex]) & 0xFF); + maskIndex++; + if (maskIndex == 4) { + maskIndex = 0; + } + payloadWritten++; + dest.put(b); + } + if (payloadWritten == payloadLength) { + return TransformationResult.END_OF_FRAME; + } else if (inputBuffer.remaining() == 0) { + return TransformationResult.UNDERFLOW; + } else { + // !dest.hasRemaining() + return TransformationResult.OVERFLOW; + } + } + + @Override + public List<MessagePart> sendMessagePart(List<MessagePart> messageParts) { + // NO-OP send so simply return the message unchanged. + return messageParts; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsFrameClient.java b/src/java/nginx/unit/websocket/WsFrameClient.java new file mode 100644 index 00000000..3174c766 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsFrameClient.java @@ -0,0 +1,228 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +public class WsFrameClient extends WsFrameBase { + + private final Log log = LogFactory.getLog(WsFrameClient.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsFrameClient.class); + + private final AsyncChannelWrapper channel; + private final CompletionHandler<Integer, Void> handler; + // Not final as it may need to be re-sized + private volatile ByteBuffer response; + + public WsFrameClient(ByteBuffer response, AsyncChannelWrapper channel, WsSession wsSession, + Transformation transformation) { + super(wsSession, transformation); + this.response = response; + this.channel = channel; + this.handler = new WsFrameClientCompletionHandler(); + } + + + void startInputProcessing() { + try { + processSocketRead(); + } catch (IOException e) { + close(e); + } + } + + + private void processSocketRead() throws IOException { + while (true) { + switch (getReadState()) { + case WAITING: + if (!changeReadState(ReadState.WAITING, ReadState.PROCESSING)) { + continue; + } + while (response.hasRemaining()) { + if (isSuspended()) { + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + // There is still data available in the response buffer + // Return here so that the response buffer will not be + // cleared and there will be no data read from the + // socket. Thus when the read operation is resumed first + // the data left in the response buffer will be consumed + // and then a new socket read will be performed + return; + } + inputBuffer.mark(); + inputBuffer.position(inputBuffer.limit()).limit(inputBuffer.capacity()); + + int toCopy = Math.min(response.remaining(), inputBuffer.remaining()); + + // Copy remaining bytes read in HTTP phase to input buffer used by + // frame processing + + int orgLimit = response.limit(); + response.limit(response.position() + toCopy); + inputBuffer.put(response); + response.limit(orgLimit); + + inputBuffer.limit(inputBuffer.position()).reset(); + + // Process the data we have + processInputBuffer(); + } + response.clear(); + + // Get some more data + if (isOpen()) { + channel.read(response, null, handler); + } else { + changeReadState(ReadState.CLOSING); + } + return; + case SUSPENDING_WAIT: + if (!changeReadState(ReadState.SUSPENDING_WAIT, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrameServer.illegalReadState", getReadState())); + } + } + } + + + private final void close(Throwable t) { + changeReadState(ReadState.CLOSING); + CloseReason cr; + if (t instanceof WsIOException) { + cr = ((WsIOException) t).getCloseReason(); + } else { + cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, t.getMessage()); + } + + try { + wsSession.close(cr); + } catch (IOException ignore) { + // Ignore + } + } + + + @Override + protected boolean isMasked() { + // Data is from the server so it is not masked + return false; + } + + + @Override + protected Log getLog() { + return log; + } + + private class WsFrameClientCompletionHandler implements CompletionHandler<Integer, Void> { + + @Override + public void completed(Integer result, Void attachment) { + if (result.intValue() == -1) { + // BZ 57762. A dropped connection will get reported as EOF + // rather than as an error so handle it here. + if (isOpen()) { + // No close frame was received + close(new EOFException()); + } + // No data to process + return; + } + response.flip(); + doResumeProcessing(true); + } + + @Override + public void failed(Throwable exc, Void attachment) { + if (exc instanceof ReadBufferOverflowException) { + // response will be empty if this exception is thrown + response = ByteBuffer + .allocate(((ReadBufferOverflowException) exc).getMinBufferSize()); + response.flip(); + doResumeProcessing(false); + } else { + close(exc); + } + } + + private void doResumeProcessing(boolean checkOpenOnError) { + while (true) { + switch (getReadState()) { + case PROCESSING: + if (!changeReadState(ReadState.PROCESSING, ReadState.WAITING)) { + continue; + } + resumeProcessing(checkOpenOnError); + return; + case SUSPENDING_PROCESS: + if (!changeReadState(ReadState.SUSPENDING_PROCESS, ReadState.SUSPENDED)) { + continue; + } + return; + default: + throw new IllegalStateException( + sm.getString("wsFrame.illegalReadState", getReadState())); + } + } + } + } + + + @Override + protected void resumeProcessing() { + resumeProcessing(true); + } + + private void resumeProcessing(boolean checkOpenOnError) { + try { + processSocketRead(); + } catch (IOException e) { + if (checkOpenOnError) { + // Only send a close message on an IOException if the client + // has not yet received a close control message from the server + // as the IOException may be in response to the client + // continuing to send a message after the server sent a close + // control message. + if (isOpen()) { + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsFrameClient.ioe"), e); + } + close(e); + } + } else { + close(e); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/WsHandshakeResponse.java b/src/java/nginx/unit/websocket/WsHandshakeResponse.java new file mode 100644 index 00000000..6e57ffd5 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsHandshakeResponse.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.websocket.HandshakeResponse; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; + +/** + * Represents the response to a WebSocket handshake. + */ +public class WsHandshakeResponse implements HandshakeResponse { + + private final Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); + + + public WsHandshakeResponse() { + } + + + public WsHandshakeResponse(Map<String,List<String>> headers) { + for (Entry<String,List<String>> entry : headers.entrySet()) { + if (this.headers.containsKey(entry.getKey())) { + this.headers.get(entry.getKey()).addAll(entry.getValue()); + } else { + List<String> values = new ArrayList<>(entry.getValue()); + this.headers.put(entry.getKey(), values); + } + } + } + + + @Override + public Map<String,List<String>> getHeaders() { + return headers; + } +} diff --git a/src/java/nginx/unit/websocket/WsIOException.java b/src/java/nginx/unit/websocket/WsIOException.java new file mode 100644 index 00000000..0362dc1d --- /dev/null +++ b/src/java/nginx/unit/websocket/WsIOException.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; + +import javax.websocket.CloseReason; + +/** + * Allows the WebSocket implementation to throw an {@link IOException} that + * includes a {@link CloseReason} specific to the error that can be passed back + * to the client. + */ +public class WsIOException extends IOException { + + private static final long serialVersionUID = 1L; + + private final CloseReason closeReason; + + public WsIOException(CloseReason closeReason) { + this.closeReason = closeReason; + } + + public CloseReason getCloseReason() { + return closeReason; + } +} diff --git a/src/java/nginx/unit/websocket/WsPongMessage.java b/src/java/nginx/unit/websocket/WsPongMessage.java new file mode 100644 index 00000000..531bcda9 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsPongMessage.java @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; + +import javax.websocket.PongMessage; + +public class WsPongMessage implements PongMessage { + + private final ByteBuffer applicationData; + + + public WsPongMessage(ByteBuffer applicationData) { + byte[] dst = new byte[applicationData.limit()]; + applicationData.get(dst); + this.applicationData = ByteBuffer.wrap(dst); + } + + + @Override + public ByteBuffer getApplicationData() { + return applicationData; + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java new file mode 100644 index 00000000..0ea20795 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointAsync.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.nio.ByteBuffer; +import java.util.concurrent.Future; + +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; + +public class WsRemoteEndpointAsync extends WsRemoteEndpointBase + implements RemoteEndpoint.Async { + + WsRemoteEndpointAsync(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public long getSendTimeout() { + return base.getSendTimeout(); + } + + + @Override + public void setSendTimeout(long timeout) { + base.setSendTimeout(timeout); + } + + + @Override + public void sendText(String text, SendHandler completion) { + base.sendStringByCompletion(text, completion); + } + + + @Override + public Future<Void> sendText(String text) { + return base.sendStringByFuture(text); + } + + + @Override + public Future<Void> sendBinary(ByteBuffer data) { + return base.sendBytesByFuture(data); + } + + + @Override + public void sendBinary(ByteBuffer data, SendHandler completion) { + base.sendBytesByCompletion(data, completion); + } + + + @Override + public Future<Void> sendObject(Object obj) { + return base.sendObjectByFuture(obj); + } + + + @Override + public void sendObject(Object obj, SendHandler completion) { + base.sendObjectByCompletion(obj, completion); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java new file mode 100644 index 00000000..21cb2040 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBase.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; + +import javax.websocket.RemoteEndpoint; + +public abstract class WsRemoteEndpointBase implements RemoteEndpoint { + + protected final WsRemoteEndpointImplBase base; + + + WsRemoteEndpointBase(WsRemoteEndpointImplBase base) { + this.base = base; + } + + + @Override + public final void setBatchingAllowed(boolean batchingAllowed) throws IOException { + base.setBatchingAllowed(batchingAllowed); + } + + + @Override + public final boolean getBatchingAllowed() { + return base.getBatchingAllowed(); + } + + + @Override + public final void flushBatch() throws IOException { + base.flushBatch(); + } + + + @Override + public final void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPing(applicationData); + } + + + @Override + public final void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + base.sendPong(applicationData); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java new file mode 100644 index 00000000..2a93cc7b --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointBasic.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.RemoteEndpoint; + +public class WsRemoteEndpointBasic extends WsRemoteEndpointBase + implements RemoteEndpoint.Basic { + + WsRemoteEndpointBasic(WsRemoteEndpointImplBase base) { + super(base); + } + + + @Override + public void sendText(String text) throws IOException { + base.sendString(text); + } + + + @Override + public void sendBinary(ByteBuffer data) throws IOException { + base.sendBytes(data); + } + + + @Override + public void sendText(String fragment, boolean isLast) throws IOException { + base.sendPartialString(fragment, isLast); + } + + + @Override + public void sendBinary(ByteBuffer partialByte, boolean isLast) + throws IOException { + base.sendPartialBytes(partialByte, isLast); + } + + + @Override + public OutputStream getSendStream() throws IOException { + return base.getSendStream(); + } + + + @Override + public Writer getSendWriter() throws IOException { + return base.getSendWriter(); + } + + + @Override + public void sendObject(Object o) throws IOException, EncodeException { + base.sendObject(o); + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java new file mode 100644 index 00000000..776124fd --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplBase.java @@ -0,0 +1,1234 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.io.OutputStream; +import java.io.Writer; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.EncodeException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.buf.Utf8Encoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplBase.class); + + protected static final SendResult SENDRESULT_OK = new SendResult(); + + private final Log log = LogFactory.getLog(WsRemoteEndpointImplBase.class); // must not be static + + private final StateMachine stateMachine = new StateMachine(); + + private final IntermediateMessageHandler intermediateMessageHandler = + new IntermediateMessageHandler(this); + + private Transformation transformation = null; + private final Semaphore messagePartInProgress = new Semaphore(1); + private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>(); + private final Object messagePartLock = new Object(); + + // State + private volatile boolean closed = false; + private boolean fragmented = false; + private boolean nextFragmented = false; + private boolean text = false; + private boolean nextText = false; + + // Max size of WebSocket header is 14 bytes + private final ByteBuffer headerBuffer = ByteBuffer.allocate(14); + private final ByteBuffer outputBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final CharsetEncoder encoder = new Utf8Encoder(); + private final ByteBuffer encoderBuffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final AtomicBoolean batchingAllowed = new AtomicBoolean(false); + private volatile long sendTimeout = -1; + private WsSession wsSession; + private List<EncoderEntry> encoderEntries = new ArrayList<>(); + + private Request request; + + + protected void setTransformation(Transformation transformation) { + this.transformation = transformation; + } + + + public long getSendTimeout() { + return sendTimeout; + } + + + public void setSendTimeout(long timeout) { + this.sendTimeout = timeout; + } + + + @Override + public void setBatchingAllowed(boolean batchingAllowed) throws IOException { + boolean oldValue = this.batchingAllowed.getAndSet(batchingAllowed); + + if (oldValue && !batchingAllowed) { + flushBatch(); + } + } + + + @Override + public boolean getBatchingAllowed() { + return batchingAllowed.get(); + } + + + @Override + public void flushBatch() throws IOException { + sendMessageBlock(Constants.INTERNAL_OPCODE_FLUSH, null, true); + } + + + public void sendBytes(ByteBuffer data) throws IOException { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryStart(); + sendMessageBlock(Constants.OPCODE_BINARY, data, true); + stateMachine.complete(true); + } + + + public Future<Void> sendBytesByFuture(ByteBuffer data) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendBytesByCompletion(data, f2sh); + return f2sh; + } + + + public void sendBytesByCompletion(ByteBuffer data, SendHandler handler) { + if (data == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + StateUpdateSendHandler sush = new StateUpdateSendHandler(handler, stateMachine); + stateMachine.binaryStart(); + startMessage(Constants.OPCODE_BINARY, data, true, sush); + } + + + public void sendPartialBytes(ByteBuffer partialByte, boolean last) + throws IOException { + if (partialByte == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.binaryPartialStart(); + sendMessageBlock(Constants.OPCODE_BINARY, partialByte, last); + stateMachine.complete(last); + } + + + @Override + public void sendPing(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PING, applicationData, true); + } + + + @Override + public void sendPong(ByteBuffer applicationData) throws IOException, + IllegalArgumentException { + if (applicationData.remaining() > 125) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.tooMuchData")); + } + sendMessageBlock(Constants.OPCODE_PONG, applicationData, true); + } + + + public void sendString(String text) throws IOException { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textStart(); + sendMessageBlock(CharBuffer.wrap(text), true); + } + + + public Future<Void> sendStringByFuture(String text) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendStringByCompletion(text, f2sh); + return f2sh; + } + + + public void sendStringByCompletion(String text, SendHandler handler) { + if (text == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (handler == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + stateMachine.textStart(); + TextMessageSendHandler tmsh = new TextMessageSendHandler(handler, + CharBuffer.wrap(text), true, encoder, encoderBuffer, this); + tmsh.write(); + // TextMessageSendHandler will update stateMachine when it completes + } + + + public void sendPartialString(String fragment, boolean isLast) + throws IOException { + if (fragment == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + stateMachine.textPartialStart(); + sendMessageBlock(CharBuffer.wrap(fragment), isLast); + } + + + public OutputStream getSendStream() { + stateMachine.streamStart(); + return new WsOutputStream(this); + } + + + public Writer getSendWriter() { + stateMachine.writeStart(); + return new WsWriter(this); + } + + + void sendMessageBlock(CharBuffer part, boolean last) throws IOException { + long timeoutExpiry = getTimeoutExpiry(); + boolean isDone = false; + while (!isDone) { + encoderBuffer.clear(); + CoderResult cr = encoder.encode(part, encoderBuffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + encoderBuffer.flip(); + sendMessageBlock(Constants.OPCODE_TEXT, encoderBuffer, last && isDone, timeoutExpiry); + } + stateMachine.complete(last); + } + + + void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last) + throws IOException { + sendMessageBlock(opCode, payload, last, getTimeoutExpiry()); + } + + + private long getTimeoutExpiry() { + // Get the timeout before we send the message. The message may + // trigger a session close and depending on timing the client + // session may close before we can read the timeout. + long timeout = getBlockingSendTimeout(); + if (timeout < 0) { + return Long.MAX_VALUE; + } else { + return System.currentTimeMillis() + timeout; + } + } + + private byte currentOpCode = Constants.OPCODE_CONTINUATION; + + private void sendMessageBlock(byte opCode, ByteBuffer payload, boolean last, + long timeoutExpiry) throws IOException { + wsSession.updateLastActive(); + + if (opCode == currentOpCode) { + opCode = Constants.OPCODE_CONTINUATION; + } + + request.sendWsFrame(payload, opCode, last, timeoutExpiry); + + if (!last && opCode != Constants.OPCODE_CONTINUATION) { + currentOpCode = opCode; + } + + if (last && opCode == Constants.OPCODE_CONTINUATION) { + currentOpCode = Constants.OPCODE_CONTINUATION; + } + } + + + void startMessage(byte opCode, ByteBuffer payload, boolean last, + SendHandler handler) { + + wsSession.updateLastActive(); + + List<MessagePart> messageParts = new ArrayList<>(); + messageParts.add(new MessagePart(last, 0, opCode, payload, + intermediateMessageHandler, + new EndMessageHandler(this, handler), -1)); + + messageParts = transformation.sendMessagePart(messageParts); + + // Some extensions/transformations may buffer messages so it is possible + // that no message parts will be returned. If this is the case the + // trigger the supplied SendHandler + if (messageParts.size() == 0) { + handler.onResult(new SendResult()); + return; + } + + MessagePart mp = messageParts.remove(0); + + boolean doWrite = false; + synchronized (messagePartLock) { + if (Constants.OPCODE_CLOSE == mp.getOpCode() && getBatchingAllowed()) { + // Should not happen. To late to send batched messages now since + // the session has been closed. Complain loudly. + log.warn(sm.getString("wsRemoteEndpoint.flushOnCloseFailed")); + } + if (messagePartInProgress.tryAcquire()) { + doWrite = true; + } else { + // When a control message is sent while another message is being + // sent, the control message is queued. Chances are the + // subsequent data message part will end up queued while the + // control message is sent. The logic in this class (state + // machine, EndMessageHandler, TextMessageSendHandler) ensures + // that there will only ever be one data message part in the + // queue. There could be multiple control messages in the queue. + + // Add it to the queue + messagePartQueue.add(mp); + } + // Add any remaining messages to the queue + messagePartQueue.addAll(messageParts); + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mp); + } + } + + + void endMessage(SendHandler handler, SendResult result) { + boolean doWrite = false; + MessagePart mpNext = null; + synchronized (messagePartLock) { + + fragmented = nextFragmented; + text = nextText; + + mpNext = messagePartQueue.poll(); + if (mpNext == null) { + messagePartInProgress.release(); + } else if (!closed){ + // Session may have been closed unexpectedly in the middle of + // sending a fragmented message closing the endpoint. If this + // happens, clearly there is no point trying to send the rest of + // the message. + doWrite = true; + } + } + if (doWrite) { + // Actual write has to be outside sync block to avoid possible + // deadlock between messagePartLock and writeLock in + // o.a.coyote.http11.upgrade.AbstractServletOutputStream + writeMessagePart(mpNext); + } + + wsSession.updateLastActive(); + + // Some handlers, such as the IntermediateMessageHandler, do not have a + // nested handler so handler may be null. + if (handler != null) { + handler.onResult(result); + } + } + + + void writeMessagePart(MessagePart mp) { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closed")); + } + + if (Constants.INTERNAL_OPCODE_FLUSH == mp.getOpCode()) { + nextFragmented = fragmented; + nextText = text; + outputBuffer.flip(); + SendHandler flushHandler = new OutputBufferFlushSendHandler( + outputBuffer, mp.getEndHandler()); + doWrite(flushHandler, mp.getBlockingWriteTimeoutExpiry(), outputBuffer); + return; + } + + // Control messages may be sent in the middle of fragmented message + // so they have no effect on the fragmented or text flags + boolean first; + if (Util.isControl(mp.getOpCode())) { + nextFragmented = fragmented; + nextText = text; + if (mp.getOpCode() == Constants.OPCODE_CLOSE) { + closed = true; + } + first = true; + } else { + boolean isText = Util.isText(mp.getOpCode()); + + if (fragmented) { + // Currently fragmented + if (text != isText) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.changeType")); + } + nextText = text; + nextFragmented = !mp.isFin(); + first = false; + } else { + // Wasn't fragmented. Might be now + if (mp.isFin()) { + nextFragmented = false; + } else { + nextFragmented = true; + nextText = isText; + } + first = true; + } + } + + byte[] mask; + + if (isMasked()) { + mask = Util.generateMask(); + } else { + mask = null; + } + + headerBuffer.clear(); + writeHeader(headerBuffer, mp.isFin(), mp.getRsv(), mp.getOpCode(), + isMasked(), mp.getPayload(), mask, first); + headerBuffer.flip(); + + if (getBatchingAllowed() || isMasked()) { + // Need to write via output buffer + OutputBufferSendHandler obsh = new OutputBufferSendHandler( + mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload(), mask, + outputBuffer, !getBatchingAllowed(), this); + obsh.write(); + } else { + // Can write directly + doWrite(mp.getEndHandler(), mp.getBlockingWriteTimeoutExpiry(), + headerBuffer, mp.getPayload()); + } + } + + + private long getBlockingSendTimeout() { + Object obj = wsSession.getUserProperties().get(Constants.BLOCKING_SEND_TIMEOUT_PROPERTY); + Long userTimeout = null; + if (obj instanceof Long) { + userTimeout = (Long) obj; + } + if (userTimeout == null) { + return Constants.DEFAULT_BLOCKING_SEND_TIMEOUT; + } else { + return userTimeout.longValue(); + } + } + + + /** + * Wraps the user provided handler so that the end point is notified when + * the message is complete. + */ + private static class EndMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + private final SendHandler handler; + + public EndMessageHandler(WsRemoteEndpointImplBase endpoint, + SendHandler handler) { + this.endpoint = endpoint; + this.handler = handler; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(handler, result); + } + } + + + /** + * If a transformation needs to split a {@link MessagePart} into multiple + * {@link MessagePart}s, it uses this handler as the end handler for each of + * the additional {@link MessagePart}s. This handler notifies this this + * class that the {@link MessagePart} has been processed and that the next + * {@link MessagePart} in the queue should be started. The final + * {@link MessagePart} will use the {@link EndMessageHandler} provided with + * the original {@link MessagePart}. + */ + private static class IntermediateMessageHandler implements SendHandler { + + private final WsRemoteEndpointImplBase endpoint; + + public IntermediateMessageHandler(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + + @Override + public void onResult(SendResult result) { + endpoint.endMessage(null, result); + } + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObject(Object obj) throws IOException, EncodeException { + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendString(msg); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytes(msg); + return; + } + + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendString(msg); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytes(msg); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } + + + public Future<Void> sendObjectByFuture(Object obj) { + FutureToSendHandler f2sh = new FutureToSendHandler(wsSession); + sendObjectByCompletion(obj, f2sh); + return f2sh; + } + + + @SuppressWarnings({"unchecked", "rawtypes"}) + public void sendObjectByCompletion(Object obj, SendHandler completion) { + + if (obj == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullData")); + } + if (completion == null) { + throw new IllegalArgumentException(sm.getString("wsRemoteEndpoint.nullHandler")); + } + + /* + * Note that the implementation will convert primitives and their object + * equivalents by default but that users are free to specify their own + * encoders and decoders for this if they wish. + */ + Encoder encoder = findEncoder(obj); + if (encoder == null && Util.isPrimitive(obj.getClass())) { + String msg = obj.toString(); + sendStringByCompletion(msg, completion); + return; + } + if (encoder == null && byte[].class.isAssignableFrom(obj.getClass())) { + ByteBuffer msg = ByteBuffer.wrap((byte[]) obj); + sendBytesByCompletion(msg, completion); + return; + } + + try { + if (encoder instanceof Encoder.Text) { + String msg = ((Encoder.Text) encoder).encode(obj); + sendStringByCompletion(msg, completion); + } else if (encoder instanceof Encoder.TextStream) { + try (Writer w = getSendWriter()) { + ((Encoder.TextStream) encoder).encode(obj, w); + } + completion.onResult(new SendResult()); + } else if (encoder instanceof Encoder.Binary) { + ByteBuffer msg = ((Encoder.Binary) encoder).encode(obj); + sendBytesByCompletion(msg, completion); + } else if (encoder instanceof Encoder.BinaryStream) { + try (OutputStream os = getSendStream()) { + ((Encoder.BinaryStream) encoder).encode(obj, os); + } + completion.onResult(new SendResult()); + } else { + throw new EncodeException(obj, sm.getString( + "wsRemoteEndpoint.noEncoder", obj.getClass())); + } + } catch (Exception e) { + SendResult sr = new SendResult(e); + completion.onResult(sr); + } + } + + + protected void setSession(WsSession wsSession) { + this.wsSession = wsSession; + } + + + protected void setRequest(Request request) { + this.request = request; + } + + protected void setEncoders(EndpointConfig endpointConfig) + throws DeploymentException { + encoderEntries.clear(); + for (Class<? extends Encoder> encoderClazz : + endpointConfig.getEncoders()) { + Encoder instance; + try { + instance = encoderClazz.getConstructor().newInstance(); + instance.init(endpointConfig); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("wsRemoteEndpoint.invalidEncoder", + encoderClazz.getName()), e); + } + EncoderEntry entry = new EncoderEntry( + Util.getEncoderType(encoderClazz), instance); + encoderEntries.add(entry); + } + } + + + private Encoder findEncoder(Object obj) { + for (EncoderEntry entry : encoderEntries) { + if (entry.getClazz().isAssignableFrom(obj.getClass())) { + return entry.getEncoder(); + } + } + return null; + } + + + public final void close() { + for (EncoderEntry entry : encoderEntries) { + entry.getEncoder().destroy(); + } + + request.closeWs(); + } + + + protected abstract void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data); + protected abstract boolean isMasked(); + protected abstract void doClose(); + + private static void writeHeader(ByteBuffer headerBuffer, boolean fin, + int rsv, byte opCode, boolean masked, ByteBuffer payload, + byte[] mask, boolean first) { + + byte b = 0; + + if (fin) { + // Set the fin bit + b -= 128; + } + + b += (rsv << 4); + + if (first) { + // This is the first fragment of this message + b += opCode; + } + // If not the first fragment, it is a continuation with opCode of zero + + headerBuffer.put(b); + + if (masked) { + b = (byte) 0x80; + } else { + b = 0; + } + + // Next write the mask && length length + if (payload.limit() < 126) { + headerBuffer.put((byte) (payload.limit() | b)); + } else if (payload.limit() < 65536) { + headerBuffer.put((byte) (126 | b)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } else { + // Will never be more than 2^31-1 + headerBuffer.put((byte) (127 | b)); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) 0); + headerBuffer.put((byte) (payload.limit() >>> 24)); + headerBuffer.put((byte) (payload.limit() >>> 16)); + headerBuffer.put((byte) (payload.limit() >>> 8)); + headerBuffer.put((byte) (payload.limit() & 0xFF)); + } + if (masked) { + headerBuffer.put(mask[0]); + headerBuffer.put(mask[1]); + headerBuffer.put(mask[2]); + headerBuffer.put(mask[3]); + } + } + + + private class TextMessageSendHandler implements SendHandler { + + private final SendHandler handler; + private final CharBuffer message; + private final boolean isLast; + private final CharsetEncoder encoder; + private final ByteBuffer buffer; + private final WsRemoteEndpointImplBase endpoint; + private volatile boolean isDone = false; + + public TextMessageSendHandler(SendHandler handler, CharBuffer message, + boolean isLast, CharsetEncoder encoder, + ByteBuffer encoderBuffer, WsRemoteEndpointImplBase endpoint) { + this.handler = handler; + this.message = message; + this.isLast = isLast; + this.encoder = encoder.reset(); + this.buffer = encoderBuffer; + this.endpoint = endpoint; + } + + public void write() { + buffer.clear(); + CoderResult cr = encoder.encode(message, buffer, true); + if (cr.isError()) { + throw new IllegalArgumentException(cr.toString()); + } + isDone = !cr.isOverflow(); + buffer.flip(); + endpoint.startMessage(Constants.OPCODE_TEXT, buffer, + isDone && isLast, this); + } + + @Override + public void onResult(SendResult result) { + if (isDone) { + endpoint.stateMachine.complete(isLast); + handler.onResult(result); + } else if(!result.isOK()) { + handler.onResult(result); + } else if (closed){ + SendResult sr = new SendResult(new IOException( + sm.getString("wsRemoteEndpoint.closedDuringMessage"))); + handler.onResult(sr); + } else { + write(); + } + } + } + + + /** + * Used to write data to the output buffer, flushing the buffer if it fills + * up. + */ + private static class OutputBufferSendHandler implements SendHandler { + + private final SendHandler handler; + private final long blockingWriteTimeoutExpiry; + private final ByteBuffer headerBuffer; + private final ByteBuffer payload; + private final byte[] mask; + private final ByteBuffer outputBuffer; + private final boolean flushRequired; + private final WsRemoteEndpointImplBase endpoint; + private int maskIndex = 0; + + public OutputBufferSendHandler(SendHandler completion, + long blockingWriteTimeoutExpiry, + ByteBuffer headerBuffer, ByteBuffer payload, byte[] mask, + ByteBuffer outputBuffer, boolean flushRequired, + WsRemoteEndpointImplBase endpoint) { + this.blockingWriteTimeoutExpiry = blockingWriteTimeoutExpiry; + this.handler = completion; + this.headerBuffer = headerBuffer; + this.payload = payload; + this.mask = mask; + this.outputBuffer = outputBuffer; + this.flushRequired = flushRequired; + this.endpoint = endpoint; + } + + public void write() { + // Write the header + while (headerBuffer.hasRemaining() && outputBuffer.hasRemaining()) { + outputBuffer.put(headerBuffer.get()); + } + if (headerBuffer.hasRemaining()) { + // Still more headers to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + // Write the payload + int payloadLeft = payload.remaining(); + int payloadLimit = payload.limit(); + int outputSpace = outputBuffer.remaining(); + int toWrite = payloadLeft; + + if (payloadLeft > outputSpace) { + toWrite = outputSpace; + // Temporarily reduce the limit + payload.limit(payload.position() + toWrite); + } + + if (mask == null) { + // Use a bulk copy + outputBuffer.put(payload); + } else { + for (int i = 0; i < toWrite; i++) { + outputBuffer.put( + (byte) (payload.get() ^ (mask[maskIndex++] & 0xFF))); + if (maskIndex > 3) { + maskIndex = 0; + } + } + } + + if (payloadLeft > outputSpace) { + // Restore the original limit + payload.limit(payloadLimit); + // Still more data to write, need to flush + outputBuffer.flip(); + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + return; + } + + if (flushRequired) { + outputBuffer.flip(); + if (outputBuffer.remaining() == 0) { + handler.onResult(SENDRESULT_OK); + } else { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } + } else { + handler.onResult(SENDRESULT_OK); + } + } + + // ------------------------------------------------- SendHandler methods + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + if (outputBuffer.hasRemaining()) { + endpoint.doWrite(this, blockingWriteTimeoutExpiry, outputBuffer); + } else { + outputBuffer.clear(); + write(); + } + } else { + handler.onResult(result); + } + } + } + + + /** + * Ensures that the output buffer is cleared after it has been flushed. + */ + private static class OutputBufferFlushSendHandler implements SendHandler { + + private final ByteBuffer outputBuffer; + private final SendHandler handler; + + public OutputBufferFlushSendHandler(ByteBuffer outputBuffer, SendHandler handler) { + this.outputBuffer = outputBuffer; + this.handler = handler; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + outputBuffer.clear(); + } + handler.onResult(result); + } + } + + + private static class WsOutputStream extends OutputStream { + + private final WsRemoteEndpointImplBase endpoint; + private final ByteBuffer buffer = ByteBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsOutputStream(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(int b) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + buffer.put((byte) b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > b.length) || (len < 0) || + ((off + len) > b.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(b, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(b, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedOutputStream")); + } + + // Optimisation. If there is no data to flush then do not send an + // empty message. + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(Constants.OPCODE_BINARY, buffer, last); + } + endpoint.stateMachine.complete(last); + buffer.clear(); + } + } + + + private static class WsWriter extends Writer { + + private final WsRemoteEndpointImplBase endpoint; + private final CharBuffer buffer = CharBuffer.allocate(Constants.DEFAULT_BUFFER_SIZE); + private final Object closeLock = new Object(); + private volatile boolean closed = false; + private volatile boolean used = false; + + public WsWriter(WsRemoteEndpointImplBase endpoint) { + this.endpoint = endpoint; + } + + @Override + public void write(char[] cbuf, int off, int len) throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + if (len == 0) { + return; + } + if ((off < 0) || (off > cbuf.length) || (len < 0) || + ((off + len) > cbuf.length) || ((off + len) < 0)) { + throw new IndexOutOfBoundsException(); + } + + used = true; + if (buffer.remaining() == 0) { + flush(); + } + int remaining = buffer.remaining(); + int written = 0; + + while (remaining < len - written) { + buffer.put(cbuf, off + written, remaining); + written += remaining; + flush(); + remaining = buffer.remaining(); + } + buffer.put(cbuf, off + written, len - written); + } + + @Override + public void flush() throws IOException { + if (closed) { + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.closedWriter")); + } + + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || buffer.position() > 0) { + doWrite(false); + } + } + + @Override + public void close() throws IOException { + synchronized (closeLock) { + if (closed) { + return; + } + closed = true; + } + + doWrite(true); + } + + private void doWrite(boolean last) throws IOException { + if (!Constants.STREAMS_DROP_EMPTY_MESSAGES || used) { + buffer.flip(); + endpoint.sendMessageBlock(buffer, last); + buffer.clear(); + } else { + endpoint.stateMachine.complete(last); + } + } + } + + + private static class EncoderEntry { + + private final Class<?> clazz; + private final Encoder encoder; + + public EncoderEntry(Class<?> clazz, Encoder encoder) { + this.clazz = clazz; + this.encoder = encoder; + } + + public Class<?> getClazz() { + return clazz; + } + + public Encoder getEncoder() { + return encoder; + } + } + + + private enum State { + OPEN, + STREAM_WRITING, + WRITER_WRITING, + BINARY_PARTIAL_WRITING, + BINARY_PARTIAL_READY, + BINARY_FULL_WRITING, + TEXT_PARTIAL_WRITING, + TEXT_PARTIAL_READY, + TEXT_FULL_WRITING + } + + + private static class StateMachine { + private State state = State.OPEN; + + public synchronized void streamStart() { + checkState(State.OPEN); + state = State.STREAM_WRITING; + } + + public synchronized void writeStart() { + checkState(State.OPEN); + state = State.WRITER_WRITING; + } + + public synchronized void binaryPartialStart() { + checkState(State.OPEN, State.BINARY_PARTIAL_READY); + state = State.BINARY_PARTIAL_WRITING; + } + + public synchronized void binaryStart() { + checkState(State.OPEN); + state = State.BINARY_FULL_WRITING; + } + + public synchronized void textPartialStart() { + checkState(State.OPEN, State.TEXT_PARTIAL_READY); + state = State.TEXT_PARTIAL_WRITING; + } + + public synchronized void textStart() { + checkState(State.OPEN); + state = State.TEXT_FULL_WRITING; + } + + public synchronized void complete(boolean last) { + if (last) { + checkState(State.TEXT_PARTIAL_WRITING, State.TEXT_FULL_WRITING, + State.BINARY_PARTIAL_WRITING, State.BINARY_FULL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + state = State.OPEN; + } else { + checkState(State.TEXT_PARTIAL_WRITING, State.BINARY_PARTIAL_WRITING, + State.STREAM_WRITING, State.WRITER_WRITING); + if (state == State.TEXT_PARTIAL_WRITING) { + state = State.TEXT_PARTIAL_READY; + } else if (state == State.BINARY_PARTIAL_WRITING){ + state = State.BINARY_PARTIAL_READY; + } else if (state == State.WRITER_WRITING) { + // NO-OP. Leave state as is. + } else if (state == State.STREAM_WRITING) { + // NO-OP. Leave state as is. + } else { + // Should never happen + // The if ... else ... blocks above should cover all states + // permitted by the preceding checkState() call + throw new IllegalStateException( + "BUG: This code should never be called"); + } + } + } + + private void checkState(State... required) { + for (State state : required) { + if (this.state == state) { + return; + } + } + throw new IllegalStateException( + sm.getString("wsRemoteEndpoint.wrongState", this.state)); + } + } + + + private static class StateUpdateSendHandler implements SendHandler { + + private final SendHandler handler; + private final StateMachine stateMachine; + + public StateUpdateSendHandler(SendHandler handler, StateMachine stateMachine) { + this.handler = handler; + this.stateMachine = stateMachine; + } + + @Override + public void onResult(SendResult result) { + if (result.isOK()) { + stateMachine.complete(true); + } + handler.onResult(result); + } + } + + + private static class BlockingSendHandler implements SendHandler { + + private SendResult sendResult = null; + + @Override + public void onResult(SendResult result) { + sendResult = result; + } + + public SendResult getSendResult() { + return sendResult; + } + } +} diff --git a/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java new file mode 100644 index 00000000..70b66789 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsRemoteEndpointImplClient.java @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase { + + private final AsyncChannelWrapper channel; + + public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) { + this.channel = channel; + } + + + @Override + protected boolean isMasked() { + return true; + } + + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... data) { + long timeout; + for (ByteBuffer byteBuffer : data) { + if (blockingWriteTimeoutExpiry == -1) { + timeout = getSendTimeout(); + if (timeout < 1) { + timeout = Long.MAX_VALUE; + } + } else { + timeout = blockingWriteTimeoutExpiry - System.currentTimeMillis(); + if (timeout < 0) { + SendResult sr = new SendResult(new IOException("Blocking write timeout")); + handler.onResult(sr); + } + } + + try { + channel.write(byteBuffer).get(timeout, TimeUnit.MILLISECONDS); + } catch (InterruptedException | ExecutionException | TimeoutException e) { + handler.onResult(new SendResult(e)); + return; + } + } + handler.onResult(SENDRESULT_OK); + } + + @Override + protected void doClose() { + channel.close(); + } +} diff --git a/src/java/nginx/unit/websocket/WsSession.java b/src/java/nginx/unit/websocket/WsSession.java new file mode 100644 index 00000000..b654eb37 --- /dev/null +++ b/src/java/nginx/unit/websocket/WsSession.java @@ -0,0 +1,1070 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.IOException; +import java.net.URI; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.channels.WritePendingException; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.nio.charset.StandardCharsets; +import java.security.Principal; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; + +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.MessageHandler.Partial; +import javax.websocket.MessageHandler.Whole; +import javax.websocket.PongMessage; +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendResult; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.InstanceManagerBindings; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.buf.Utf8Decoder; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.Request; + +public class WsSession implements Session { + + // An ellipsis is a single character that looks like three periods in a row + // and is used to indicate a continuation. + private static final byte[] ELLIPSIS_BYTES = "\u2026".getBytes(StandardCharsets.UTF_8); + // An ellipsis is three bytes in UTF-8 + private static final int ELLIPSIS_BYTES_LEN = ELLIPSIS_BYTES.length; + + private static final StringManager sm = StringManager.getManager(WsSession.class); + private static AtomicLong ids = new AtomicLong(0); + + private final Log log = LogFactory.getLog(WsSession.class); // must not be static + + private final CharsetDecoder utf8DecoderMessage = new Utf8Decoder(). + onMalformedInput(CodingErrorAction.REPORT). + onUnmappableCharacter(CodingErrorAction.REPORT); + + private final Endpoint localEndpoint; + private final WsRemoteEndpointImplBase wsRemoteEndpoint; + private final RemoteEndpoint.Async remoteEndpointAsync; + private final RemoteEndpoint.Basic remoteEndpointBasic; + private final ClassLoader applicationClassLoader; + private final WsWebSocketContainer webSocketContainer; + private final URI requestUri; + private final Map<String, List<String>> requestParameterMap; + private final String queryString; + private final Principal userPrincipal; + private final EndpointConfig endpointConfig; + + private final List<Extension> negotiatedExtensions; + private final String subProtocol; + private final Map<String, String> pathParameters; + private final boolean secure; + private final String httpSessionId; + private final String id; + + // Expected to handle message types of <String> only + private volatile MessageHandler textMessageHandler = null; + // Expected to handle message types of <ByteBuffer> only + private volatile MessageHandler binaryMessageHandler = null; + private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null; + private volatile State state = State.OPEN; + private final Object stateLock = new Object(); + private final Map<String, Object> userProperties = new ConcurrentHashMap<>(); + private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long maxIdleTimeout = 0; + private volatile long lastActive = System.currentTimeMillis(); + private Map<FutureToSendHandler, FutureToSendHandler> futures = new ConcurrentHashMap<>(); + + private CharBuffer messageBufferText; + private ByteBuffer binaryBuffer; + private byte startOpCode = Constants.OPCODE_CONTINUATION; + + /** + * Creates a new WebSocket session for communication between the two + * provided end points. The result of {@link Thread#getContextClassLoader()} + * at the time this constructor is called will be used when calling + * {@link Endpoint#onClose(Session, CloseReason)}. + * + * @param localEndpoint The end point managed by this code + * @param wsRemoteEndpoint The other / remote endpoint + * @param wsWebSocketContainer The container that created this session + * @param requestUri The URI used to connect to this endpoint or + * <code>null</code> is this is a client session + * @param requestParameterMap The parameters associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param queryString The query string associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param userPrincipal The principal associated with the request + * that initiated this session or + * <code>null</code> if this is a client session + * @param httpSessionId The HTTP session ID associated with the + * request that initiated this session or + * <code>null</code> if this is a client session + * @param negotiatedExtensions The agreed extensions to use for this session + * @param subProtocol The agreed subprotocol to use for this + * session + * @param pathParameters The path parameters associated with the + * request that initiated this session or + * <code>null</code> if this is a client session + * @param secure Was this session initiated over a secure + * connection? + * @param endpointConfig The configuration information for the + * endpoint + * @throws DeploymentException if an invalid encode is specified + */ + public WsSession(Endpoint localEndpoint, + WsRemoteEndpointImplBase wsRemoteEndpoint, + WsWebSocketContainer wsWebSocketContainer, + URI requestUri, Map<String, List<String>> requestParameterMap, + String queryString, Principal userPrincipal, String httpSessionId, + List<Extension> negotiatedExtensions, String subProtocol, Map<String, String> pathParameters, + boolean secure, EndpointConfig endpointConfig, + Request request) throws DeploymentException { + this.localEndpoint = localEndpoint; + this.wsRemoteEndpoint = wsRemoteEndpoint; + this.wsRemoteEndpoint.setSession(this); + this.wsRemoteEndpoint.setRequest(request); + + request.setWsSession(this); + + this.remoteEndpointAsync = new WsRemoteEndpointAsync(wsRemoteEndpoint); + this.remoteEndpointBasic = new WsRemoteEndpointBasic(wsRemoteEndpoint); + this.webSocketContainer = wsWebSocketContainer; + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + wsRemoteEndpoint.setSendTimeout(wsWebSocketContainer.getDefaultAsyncSendTimeout()); + this.maxBinaryMessageBufferSize = webSocketContainer.getDefaultMaxBinaryMessageBufferSize(); + this.maxTextMessageBufferSize = webSocketContainer.getDefaultMaxTextMessageBufferSize(); + this.maxIdleTimeout = webSocketContainer.getDefaultMaxSessionIdleTimeout(); + this.requestUri = requestUri; + if (requestParameterMap == null) { + this.requestParameterMap = Collections.emptyMap(); + } else { + this.requestParameterMap = requestParameterMap; + } + this.queryString = queryString; + this.userPrincipal = userPrincipal; + this.httpSessionId = httpSessionId; + this.negotiatedExtensions = negotiatedExtensions; + if (subProtocol == null) { + this.subProtocol = ""; + } else { + this.subProtocol = subProtocol; + } + this.pathParameters = pathParameters; + this.secure = secure; + this.wsRemoteEndpoint.setEncoders(endpointConfig); + this.endpointConfig = endpointConfig; + + this.userProperties.putAll(endpointConfig.getUserProperties()); + this.id = Long.toHexString(ids.getAndIncrement()); + + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + if (instanceManager != null) { + try { + instanceManager.newInstance(localEndpoint); + } catch (Exception e) { + throw new DeploymentException(sm.getString("wsSession.instanceNew"), e); + } + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.created", id)); + } + + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + public static String wsSession_test() { + return sm.getString("wsSession.instanceNew"); + } + + + @Override + public WebSocketContainer getContainer() { + checkState(); + return webSocketContainer; + } + + + @Override + public void addMessageHandler(MessageHandler listener) { + Class<?> target = Util.getMessageType(listener); + doAddMessageHandler(target, listener); + } + + + @Override + public <T> void addMessageHandler(Class<T> clazz, Partial<T> handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @Override + public <T> void addMessageHandler(Class<T> clazz, Whole<T> handler) + throws IllegalStateException { + doAddMessageHandler(clazz, handler); + } + + + @SuppressWarnings("unchecked") + private void doAddMessageHandler(Class<?> target, MessageHandler listener) { + checkState(); + + // Message handlers that require decoders may map to text messages, + // binary messages, both or neither. + + // The frame processing code expects binary message handlers to + // accept ByteBuffer + + // Use the POJO message handler wrappers as they are designed to wrap + // arbitrary objects with MessageHandlers and can wrap MessageHandlers + // just as easily. + + Set<MessageHandlerResult> mhResults = Util.getMessageHandlers(target, listener, + endpointConfig, this); + + for (MessageHandlerResult mhResult : mhResults) { + switch (mhResult.getType()) { + case TEXT: { + if (textMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerText")); + } + textMessageHandler = mhResult.getHandler(); + break; + } + case BINARY: { + if (binaryMessageHandler != null) { + throw new IllegalStateException( + sm.getString("wsSession.duplicateHandlerBinary")); + } + binaryMessageHandler = mhResult.getHandler(); + break; + } + case PONG: { + if (pongMessageHandler != null) { + throw new IllegalStateException(sm.getString("wsSession.duplicateHandlerPong")); + } + MessageHandler handler = mhResult.getHandler(); + if (handler instanceof MessageHandler.Whole<?>) { + pongMessageHandler = (MessageHandler.Whole<PongMessage>) handler; + } else { + throw new IllegalStateException( + sm.getString("wsSession.invalidHandlerTypePong")); + } + + break; + } + default: { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandlerType", listener, mhResult.getType())); + } + } + } + } + + + @Override + public Set<MessageHandler> getMessageHandlers() { + checkState(); + Set<MessageHandler> result = new HashSet<>(); + if (binaryMessageHandler != null) { + result.add(binaryMessageHandler); + } + if (textMessageHandler != null) { + result.add(textMessageHandler); + } + if (pongMessageHandler != null) { + result.add(pongMessageHandler); + } + return result; + } + + + @Override + public void removeMessageHandler(MessageHandler listener) { + checkState(); + if (listener == null) { + return; + } + + MessageHandler wrapped = null; + + if (listener instanceof WrappedMessageHandler) { + wrapped = ((WrappedMessageHandler) listener).getWrappedHandler(); + } + + if (wrapped == null) { + wrapped = listener; + } + + boolean removed = false; + if (wrapped.equals(textMessageHandler) || listener.equals(textMessageHandler)) { + textMessageHandler = null; + removed = true; + } + + if (wrapped.equals(binaryMessageHandler) || listener.equals(binaryMessageHandler)) { + binaryMessageHandler = null; + removed = true; + } + + if (wrapped.equals(pongMessageHandler) || listener.equals(pongMessageHandler)) { + pongMessageHandler = null; + removed = true; + } + + if (!removed) { + // ISE for now. Could swallow this silently / log this if the ISE + // becomes a problem + throw new IllegalStateException( + sm.getString("wsSession.removeHandlerFailed", listener)); + } + } + + + @Override + public String getProtocolVersion() { + checkState(); + return Constants.WS_VERSION_HEADER_VALUE; + } + + + @Override + public String getNegotiatedSubprotocol() { + checkState(); + return subProtocol; + } + + + @Override + public List<Extension> getNegotiatedExtensions() { + checkState(); + return negotiatedExtensions; + } + + + @Override + public boolean isSecure() { + checkState(); + return secure; + } + + + @Override + public boolean isOpen() { + return state == State.OPEN; + } + + + @Override + public long getMaxIdleTimeout() { + checkState(); + return maxIdleTimeout; + } + + + @Override + public void setMaxIdleTimeout(long timeout) { + checkState(); + this.maxIdleTimeout = timeout; + } + + + @Override + public void setMaxBinaryMessageBufferSize(int max) { + checkState(); + this.maxBinaryMessageBufferSize = max; + } + + + @Override + public int getMaxBinaryMessageBufferSize() { + checkState(); + return maxBinaryMessageBufferSize; + } + + + @Override + public void setMaxTextMessageBufferSize(int max) { + checkState(); + this.maxTextMessageBufferSize = max; + } + + + @Override + public int getMaxTextMessageBufferSize() { + checkState(); + return maxTextMessageBufferSize; + } + + + @Override + public Set<Session> getOpenSessions() { + checkState(); + return webSocketContainer.getOpenSessions(localEndpoint); + } + + + @Override + public RemoteEndpoint.Async getAsyncRemote() { + checkState(); + return remoteEndpointAsync; + } + + + @Override + public RemoteEndpoint.Basic getBasicRemote() { + checkState(); + return remoteEndpointBasic; + } + + + @Override + public void close() throws IOException { + close(new CloseReason(CloseCodes.NORMAL_CLOSURE, "")); + } + + + @Override + public void close(CloseReason closeReason) throws IOException { + doClose(closeReason, closeReason); + } + + + /** + * WebSocket 1.0. Section 2.1.5. + * Need internal close method as spec requires that the local endpoint + * receives a 1006 on timeout. + * + * @param closeReasonMessage The close reason to pass to the remote endpoint + * @param closeReasonLocal The close reason to pass to the local endpoint + */ + public void doClose(CloseReason closeReasonMessage, CloseReason closeReasonLocal) { + // Double-checked locking. OK because state is volatile + if (state != State.OPEN) { + return; + } + + synchronized (stateLock) { + if (state != State.OPEN) { + return; + } + + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.doClose", id)); + } + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + + state = State.OUTPUT_CLOSED; + + sendCloseMessage(closeReasonMessage); + fireEndpointOnClose(closeReasonLocal); + } + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + for (FutureToSendHandler f2sh : futures.keySet()) { + f2sh.onResult(sr); + } + } + + + /** + * Called when a close message is received. Should only ever happen once. + * Also called after a protocol error when the ProtocolHandler needs to + * force the closing of the connection. + * + * @param closeReason The reason contained within the received close + * message. + */ + public void onClose(CloseReason closeReason) { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + sendCloseMessage(closeReason); + fireEndpointOnClose(closeReason); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + public void onClose() { + + synchronized (stateLock) { + if (state != State.CLOSED) { + try { + wsRemoteEndpoint.setBatchingAllowed(false); + } catch (IOException e) { + log.warn(sm.getString("wsSession.flushFailOnClose"), e); + fireEndpointOnError(e); + } + if (state == State.OPEN) { + state = State.OUTPUT_CLOSED; + fireEndpointOnClose(new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, "")); + } + state = State.CLOSED; + + // Close the socket + wsRemoteEndpoint.close(); + } + } + } + + + private void fireEndpointOnClose(CloseReason closeReason) { + + // Fire the onClose event + Throwable throwable = null; + InstanceManager instanceManager = webSocketContainer.getInstanceManager(); + if (instanceManager == null) { + instanceManager = InstanceManagerBindings.get(applicationClassLoader); + } + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onClose(this, closeReason); + } catch (Throwable t1) { + ExceptionUtils.handleThrowable(t1); + throwable = t1; + } finally { + if (instanceManager != null) { + try { + instanceManager.destroyInstance(localEndpoint); + } catch (Throwable t2) { + ExceptionUtils.handleThrowable(t2); + if (throwable == null) { + throwable = t2; + } + } + } + t.setContextClassLoader(cl); + } + + if (throwable != null) { + fireEndpointOnError(throwable); + } + } + + + private void fireEndpointOnError(Throwable throwable) { + + // Fire the onError event + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + localEndpoint.onError(this, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void sendCloseMessage(CloseReason closeReason) { + // 125 is maximum size for the payload of a control message + ByteBuffer msg = ByteBuffer.allocate(125); + CloseCode closeCode = closeReason.getCloseCode(); + // CLOSED_ABNORMALLY should not be put on the wire + if (closeCode == CloseCodes.CLOSED_ABNORMALLY) { + // PROTOCOL_ERROR is probably better than GOING_AWAY here + msg.putShort((short) CloseCodes.PROTOCOL_ERROR.getCode()); + } else { + msg.putShort((short) closeCode.getCode()); + } + + String reason = closeReason.getReasonPhrase(); + if (reason != null && reason.length() > 0) { + appendCloseReasonWithTruncation(msg, reason); + } + msg.flip(); + try { + wsRemoteEndpoint.sendMessageBlock(Constants.OPCODE_CLOSE, msg, true); + } catch (IOException | WritePendingException e) { + // Failed to send close message. Close the socket and let the caller + // deal with the Exception + if (log.isDebugEnabled()) { + log.debug(sm.getString("wsSession.sendCloseFail", id), e); + } + wsRemoteEndpoint.close(); + // Failure to send a close message is not unexpected in the case of + // an abnormal closure (usually triggered by a failure to read/write + // from/to the client. In this case do not trigger the endpoint's + // error handling + if (closeCode != CloseCodes.CLOSED_ABNORMALLY) { + localEndpoint.onError(this, e); + } + } finally { + webSocketContainer.unregisterSession(localEndpoint, this); + } + } + + + /** + * Use protected so unit tests can access this method directly. + * @param msg The message + * @param reason The reason + */ + protected static void appendCloseReasonWithTruncation(ByteBuffer msg, String reason) { + // Once the close code has been added there are a maximum of 123 bytes + // left for the reason phrase. If it is truncated then care needs to be + // taken to ensure the bytes are not truncated in the middle of a + // multi-byte UTF-8 character. + byte[] reasonBytes = reason.getBytes(StandardCharsets.UTF_8); + + if (reasonBytes.length <= 123) { + // No need to truncate + msg.put(reasonBytes); + } else { + // Need to truncate + int remaining = 123 - ELLIPSIS_BYTES_LEN; + int pos = 0; + byte[] bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + while (remaining >= bytesNext.length) { + msg.put(bytesNext); + remaining -= bytesNext.length; + pos++; + bytesNext = reason.substring(pos, pos + 1).getBytes(StandardCharsets.UTF_8); + } + msg.put(ELLIPSIS_BYTES); + } + } + + + /** + * Make the session aware of a {@link FutureToSendHandler} that will need to + * be forcibly closed if the session closes before the + * {@link FutureToSendHandler} completes. + * @param f2sh The handler + */ + protected void registerFuture(FutureToSendHandler f2sh) { + // Ideally, this code should sync on stateLock so that the correct + // action is taken based on the current state of the connection. + // However, a sync on stateLock can't be used here as it will create the + // possibility of a dead-lock. See BZ 61183. + // Therefore, a slightly less efficient approach is used. + + // Always register the future. + futures.put(f2sh, f2sh); + + if (state == State.OPEN) { + // The session is open. The future has been registered with the open + // session. Normal processing continues. + return; + } + + // The session is closed. The future may or may not have been registered + // in time for it to be processed during session closure. + + if (f2sh.isDone()) { + // The future has completed. It is not known if the future was + // completed normally by the I/O layer or in error by doClose(). It + // doesn't matter which. There is nothing more to do here. + return; + } + + // The session is closed. The Future had not completed when last checked. + // There is a small timing window that means the Future may have been + // completed since the last check. There is also the possibility that + // the Future was not registered in time to be cleaned up during session + // close. + // Attempt to complete the Future with an error result as this ensures + // that the Future completes and any client code waiting on it does not + // hang. It is slightly inefficient since the Future may have been + // completed in another thread or another thread may be about to + // complete the Future but knowing if this is the case requires the sync + // on stateLock (see above). + // Note: If multiple attempts are made to complete the Future, the + // second and subsequent attempts are ignored. + + IOException ioe = new IOException(sm.getString("wsSession.messageFailed")); + SendResult sr = new SendResult(ioe); + f2sh.onResult(sr); + } + + + /** + * Remove a {@link FutureToSendHandler} from the set of tracked instances. + * @param f2sh The handler + */ + protected void unregisterFuture(FutureToSendHandler f2sh) { + futures.remove(f2sh); + } + + + @Override + public URI getRequestURI() { + checkState(); + return requestUri; + } + + + @Override + public Map<String, List<String>> getRequestParameterMap() { + checkState(); + return requestParameterMap; + } + + + @Override + public String getQueryString() { + checkState(); + return queryString; + } + + + @Override + public Principal getUserPrincipal() { + checkState(); + return userPrincipal; + } + + + @Override + public Map<String, String> getPathParameters() { + checkState(); + return pathParameters; + } + + + @Override + public String getId() { + return id; + } + + + @Override + public Map<String, Object> getUserProperties() { + checkState(); + return userProperties; + } + + + public Endpoint getLocal() { + return localEndpoint; + } + + + public String getHttpSessionId() { + return httpSessionId; + } + + private ByteBuffer rawFragments; + + public void processFrame(ByteBuffer buf, byte opCode, boolean last) + throws IOException + { + if (state == State.CLOSED) { + return; + } + + if (opCode == Constants.OPCODE_CONTINUATION) { + opCode = startOpCode; + + if (rawFragments != null && rawFragments.position() > 0) { + rawFragments.put(buf); + rawFragments.flip(); + buf = rawFragments; + } + } else { + if (!last && (opCode == Constants.OPCODE_BINARY || + opCode == Constants.OPCODE_TEXT)) { + startOpCode = opCode; + + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + if (last) { + startOpCode = Constants.OPCODE_CONTINUATION; + } + + if (opCode == Constants.OPCODE_PONG) { + if (pongMessageHandler != null) { + final ByteBuffer b = buf; + + PongMessage pongMessage = new PongMessage() { + @Override + public ByteBuffer getApplicationData() { + return b; + } + }; + + pongMessageHandler.onMessage(pongMessage); + } + } + + if (opCode == Constants.OPCODE_CLOSE) { + CloseReason closeReason; + + if (buf.remaining() >= 2) { + short closeCode = buf.order(ByteOrder.BIG_ENDIAN).getShort(); + + closeReason = new CloseReason( + CloseReason.CloseCodes.getCloseCode(closeCode), + buf.asCharBuffer().toString()); + } else { + closeReason = new CloseReason( + CloseReason.CloseCodes.NORMAL_CLOSURE, ""); + } + + onClose(closeReason); + } + + if (opCode == Constants.OPCODE_BINARY) { + onMessage(buf, last); + } + + if (opCode == Constants.OPCODE_TEXT) { + if (messageBufferText.position() == 0 && maxTextMessageBufferSize != messageBufferText.capacity()) { + messageBufferText = CharBuffer.allocate(maxTextMessageBufferSize); + } + + CoderResult cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + if (cr.isError()) { + throw new WsIOException(new CloseReason( + CloseCodes.NOT_CONSISTENT, + sm.getString("wsFrame.invalidUtf8"))); + } else if (cr.isOverflow()) { + // Ran out of space in text buffer - flush it + if (hasTextPartial()) { + do { + onMessage(messageBufferText, false); + + cr = utf8DecoderMessage.decode(buf, messageBufferText, last); + } while (cr.isOverflow()); + } else { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + } else if (cr.isUnderflow() && !last) { + updateRawFragments(buf, last); + + if (hasTextPartial()) { + onMessage(messageBufferText, false); + } + + return; + } + + if (last) { + utf8DecoderMessage.reset(); + } + + updateRawFragments(buf, last); + + onMessage(messageBufferText, last); + } + } + + + private boolean hasTextPartial() { + return textMessageHandler instanceof MessageHandler.Partial<?>; + } + + + private void onMessage(CharBuffer buf, boolean last) throws IOException { + buf.flip(); + try { + onMessage(buf.toString(), last); + } catch (Throwable t) { + handleThrowableOnSend(t); + } finally { + buf.clear(); + } + } + + + private void updateRawFragments(ByteBuffer buf, boolean last) { + if (!last && buf.remaining() > 0) { + if (buf == rawFragments) { + buf.compact(); + } else { + if (rawFragments == null || (rawFragments.position() == 0 && maxTextMessageBufferSize != rawFragments.capacity())) { + rawFragments = ByteBuffer.allocateDirect(maxTextMessageBufferSize); + } + rawFragments.put(buf); + } + } else { + if (rawFragments != null) { + rawFragments.clear(); + } + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(String text, boolean last) { + if (hasTextPartial()) { + ((MessageHandler.Partial<String>) textMessageHandler).onMessage(text, last); + } else { + // Caller ensures last == true if this branch is used + ((MessageHandler.Whole<String>) textMessageHandler).onMessage(text); + } + } + + + @SuppressWarnings("unchecked") + public void onMessage(ByteBuffer buf, boolean last) + throws IOException + { + if (binaryMessageHandler instanceof MessageHandler.Partial<?>) { + ((MessageHandler.Partial<ByteBuffer>) binaryMessageHandler).onMessage(buf, last); + } else { + if (last && (binaryBuffer == null || binaryBuffer.position() == 0)) { + ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(buf); + return; + } + + if (binaryBuffer == null || + (binaryBuffer.position() == 0 && binaryBuffer.capacity() != maxBinaryMessageBufferSize)) + { + binaryBuffer = ByteBuffer.allocateDirect(maxBinaryMessageBufferSize); + } + + if (binaryBuffer.remaining() < buf.remaining()) { + throw new WsIOException(new CloseReason( + CloseCodes.TOO_BIG, + sm.getString("wsFrame.textMessageTooBig"))); + } + + binaryBuffer.put(buf); + + if (last) { + binaryBuffer.flip(); + try { + ((MessageHandler.Whole<ByteBuffer>) binaryMessageHandler).onMessage(binaryBuffer); + } finally { + binaryBuffer.clear(); + } + } + } + } + + + private void handleThrowableOnSend(Throwable t) throws WsIOException { + ExceptionUtils.handleThrowable(t); + getLocal().onError(this, t); + CloseReason cr = new CloseReason(CloseCodes.CLOSED_ABNORMALLY, + sm.getString("wsFrame.ioeTriggeredClose")); + throw new WsIOException(cr); + } + + + protected MessageHandler getTextMessageHandler() { + return textMessageHandler; + } + + + protected MessageHandler getBinaryMessageHandler() { + return binaryMessageHandler; + } + + + protected MessageHandler.Whole<PongMessage> getPongMessageHandler() { + return pongMessageHandler; + } + + + protected void updateLastActive() { + lastActive = System.currentTimeMillis(); + } + + + protected void checkExpiration() { + long timeout = maxIdleTimeout; + if (timeout < 1) { + return; + } + + if (System.currentTimeMillis() - lastActive > timeout) { + String msg = sm.getString("wsSession.timeout", getId()); + if (log.isDebugEnabled()) { + log.debug(msg); + } + doClose(new CloseReason(CloseCodes.GOING_AWAY, msg), + new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg)); + } + } + + + private void checkState() { + if (state == State.CLOSED) { + /* + * As per RFC 6455, a WebSocket connection is considered to be + * closed once a peer has sent and received a WebSocket close frame. + */ + throw new IllegalStateException(sm.getString("wsSession.closed", id)); + } + } + + private enum State { + OPEN, + OUTPUT_CLOSED, + CLOSED + } +} diff --git a/src/java/nginx/unit/websocket/WsWebSocketContainer.java b/src/java/nginx/unit/websocket/WsWebSocketContainer.java new file mode 100644 index 00000000..282665ef --- /dev/null +++ b/src/java/nginx/unit/websocket/WsWebSocketContainer.java @@ -0,0 +1,1123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.EOFException; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.Proxy; +import java.net.ProxySelector; +import java.net.SocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.channels.AsynchronousChannelGroup; +import java.nio.channels.AsynchronousSocketChannel; +import java.nio.charset.StandardCharsets; +import java.security.KeyStore; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.TrustManagerFactory; +import javax.websocket.ClientEndpoint; +import javax.websocket.ClientEndpointConfig; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.Session; +import javax.websocket.WebSocketContainer; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.buf.StringUtils; +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoEndpointClient; + +public class WsWebSocketContainer implements WebSocketContainer, BackgroundProcess { + + private static final StringManager sm = StringManager.getManager(WsWebSocketContainer.class); + private static final Random RANDOM = new Random(); + private static final byte[] CRLF = new byte[] { 13, 10 }; + + private static final byte[] GET_BYTES = "GET ".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] ROOT_URI_BYTES = "/".getBytes(StandardCharsets.ISO_8859_1); + private static final byte[] HTTP_VERSION_BYTES = + " HTTP/1.1\r\n".getBytes(StandardCharsets.ISO_8859_1); + + private volatile AsynchronousChannelGroup asynchronousChannelGroup = null; + private final Object asynchronousChannelGroupLock = new Object(); + + private final Log log = LogFactory.getLog(WsWebSocketContainer.class); // must not be static + private final Map<Endpoint, Set<WsSession>> endpointSessionMap = + new HashMap<>(); + private final Map<WsSession,WsSession> sessions = new ConcurrentHashMap<>(); + private final Object endPointSessionMapLock = new Object(); + + private long defaultAsyncTimeout = -1; + private int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE; + private volatile long defaultMaxSessionIdleTimeout = 0; + private int backgroundProcessCount = 0; + private int processPeriod = Constants.DEFAULT_PROCESS_PERIOD; + + private InstanceManager instanceManager; + + InstanceManager getInstanceManager() { + return instanceManager; + } + + protected void setInstanceManager(InstanceManager instanceManager) { + this.instanceManager = instanceManager; + } + + @Override + public Session connectToServer(Object pojo, URI path) + throws DeploymentException { + + ClientEndpoint annotation = + pojo.getClass().getAnnotation(ClientEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.missingAnnotation", + pojo.getClass().getName())); + } + + Endpoint ep = new PojoEndpointClient(pojo, Arrays.asList(annotation.decoders())); + + Class<? extends ClientEndpointConfig.Configurator> configuratorClazz = + annotation.configurator(); + + ClientEndpointConfig.Configurator configurator = null; + if (!ClientEndpointConfig.Configurator.class.equals( + configuratorClazz)) { + try { + configurator = configuratorClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.defaultConfiguratorFail"), e); + } + } + + ClientEndpointConfig.Builder builder = ClientEndpointConfig.Builder.create(); + // Avoid NPE when using RI API JAR - see BZ 56343 + if (configurator != null) { + builder.configurator(configurator); + } + ClientEndpointConfig config = builder. + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + preferredSubprotocols(Arrays.asList(annotation.subprotocols())). + build(); + return connectToServer(ep, config, path); + } + + + @Override + public Session connectToServer(Class<?> annotatedEndpointClass, URI path) + throws DeploymentException { + + Object pojo; + try { + pojo = annotatedEndpointClass.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", + annotatedEndpointClass.getName()), e); + } + + return connectToServer(pojo, path); + } + + + @Override + public Session connectToServer(Class<? extends Endpoint> clazz, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + + Endpoint endpoint; + try { + endpoint = clazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.endpointCreateFail", clazz.getName()), + e); + } + + return connectToServer(endpoint, clientEndpointConfiguration, path); + } + + + @Override + public Session connectToServer(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path) + throws DeploymentException { + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, new HashSet<>()); + } + + private Session connectToServerRecursive(Endpoint endpoint, + ClientEndpointConfig clientEndpointConfiguration, URI path, + Set<URI> redirectSet) + throws DeploymentException { + + boolean secure = false; + ByteBuffer proxyConnect = null; + URI proxyPath; + + // Validate scheme (and build proxyPath) + String scheme = path.getScheme(); + if ("ws".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("http" + path.toString().substring(2)); + } else if ("wss".equalsIgnoreCase(scheme)) { + proxyPath = URI.create("https" + path.toString().substring(3)); + secure = true; + } else { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.pathWrongScheme", scheme)); + } + + // Validate host + String host = path.getHost(); + if (host == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.pathNoHost")); + } + int port = path.getPort(); + + SocketAddress sa = null; + + // Check to see if a proxy is configured. Javadoc indicates return value + // will never be null + List<Proxy> proxies = ProxySelector.getDefault().select(proxyPath); + Proxy selectedProxy = null; + for (Proxy proxy : proxies) { + if (proxy.type().equals(Proxy.Type.HTTP)) { + sa = proxy.address(); + if (sa instanceof InetSocketAddress) { + InetSocketAddress inet = (InetSocketAddress) sa; + if (inet.isUnresolved()) { + sa = new InetSocketAddress(inet.getHostName(), inet.getPort()); + } + } + selectedProxy = proxy; + break; + } + } + + // If the port is not explicitly specified, compute it based on the + // scheme + if (port == -1) { + if ("ws".equalsIgnoreCase(scheme)) { + port = 80; + } else { + // Must be wss due to scheme validation above + port = 443; + } + } + + // If sa is null, no proxy is configured so need to create sa + if (sa == null) { + sa = new InetSocketAddress(host, port); + } else { + proxyConnect = createProxyRequest(host, port); + } + + // Create the initial HTTP request to open the WebSocket connection + Map<String, List<String>> reqHeaders = createRequestHeaders(host, port, + clientEndpointConfiguration); + clientEndpointConfiguration.getConfigurator().beforeRequest(reqHeaders); + if (Constants.DEFAULT_ORIGIN_HEADER_VALUE != null + && !reqHeaders.containsKey(Constants.ORIGIN_HEADER_NAME)) { + List<String> originValues = new ArrayList<>(1); + originValues.add(Constants.DEFAULT_ORIGIN_HEADER_VALUE); + reqHeaders.put(Constants.ORIGIN_HEADER_NAME, originValues); + } + ByteBuffer request = createRequest(path, reqHeaders); + + AsynchronousSocketChannel socketChannel; + try { + socketChannel = AsynchronousSocketChannel.open(getAsynchronousChannelGroup()); + } catch (IOException ioe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.asynchronousSocketChannelFail"), ioe); + } + + Map<String,Object> userProperties = clientEndpointConfiguration.getUserProperties(); + + // Get the connection timeout + long timeout = Constants.IO_TIMEOUT_MS_DEFAULT; + String timeoutValue = (String) userProperties.get(Constants.IO_TIMEOUT_MS_PROPERTY); + if (timeoutValue != null) { + timeout = Long.valueOf(timeoutValue).intValue(); + } + + // Set-up + // Same size as the WsFrame input buffer + ByteBuffer response = ByteBuffer.allocate(getDefaultMaxBinaryMessageBufferSize()); + String subProtocol; + boolean success = false; + List<Extension> extensionsAgreed = new ArrayList<>(); + Transformation transformation = null; + + // Open the connection + Future<Void> fConnect = socketChannel.connect(sa); + AsyncChannelWrapper channel = null; + + if (proxyConnect != null) { + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + // Proxy CONNECT is clear text + channel = new AsyncChannelWrapperNonSecure(socketChannel); + writeRequest(channel, proxyConnect, timeout); + HttpResponse httpResponse = processResponse(response, channel, timeout); + if (httpResponse.getStatus() != 200) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.proxyConnectFail", selectedProxy, + Integer.toString(httpResponse.getStatus()))); + } + } catch (TimeoutException | InterruptedException | ExecutionException | + EOFException e) { + if (channel != null) { + channel.close(); + } + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } + } + + if (secure) { + // Regardless of whether a non-secure wrapper was created for a + // proxy CONNECT, need to use TLS from this point on so wrap the + // original AsynchronousSocketChannel + SSLEngine sslEngine = createSSLEngine(userProperties, host, port); + channel = new AsyncChannelWrapperSecure(socketChannel, sslEngine); + } else if (channel == null) { + // Only need to wrap as this point if it wasn't wrapped to process a + // proxy CONNECT + channel = new AsyncChannelWrapperNonSecure(socketChannel); + } + + try { + fConnect.get(timeout, TimeUnit.MILLISECONDS); + + Future<Void> fHandshake = channel.handshake(); + fHandshake.get(timeout, TimeUnit.MILLISECONDS); + + writeRequest(channel, request, timeout); + + HttpResponse httpResponse = processResponse(response, channel, timeout); + + // Check maximum permitted redirects + int maxRedirects = Constants.MAX_REDIRECTIONS_DEFAULT; + String maxRedirectsValue = + (String) userProperties.get(Constants.MAX_REDIRECTIONS_PROPERTY); + if (maxRedirectsValue != null) { + maxRedirects = Integer.parseInt(maxRedirectsValue); + } + + if (httpResponse.status != 101) { + if(isRedirectStatus(httpResponse.status)){ + List<String> locationHeader = + httpResponse.getHandshakeResponse().getHeaders().get( + Constants.LOCATION_HEADER_NAME); + + if (locationHeader == null || locationHeader.isEmpty() || + locationHeader.get(0) == null || locationHeader.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingLocationHeader", + Integer.toString(httpResponse.status))); + } + + URI redirectLocation = URI.create(locationHeader.get(0)).normalize(); + + if (!redirectLocation.isAbsolute()) { + redirectLocation = path.resolve(redirectLocation); + } + + String redirectScheme = redirectLocation.getScheme().toLowerCase(); + + if (redirectScheme.startsWith("http")) { + redirectLocation = new URI(redirectScheme.replace("http", "ws"), + redirectLocation.getUserInfo(), redirectLocation.getHost(), + redirectLocation.getPort(), redirectLocation.getPath(), + redirectLocation.getQuery(), redirectLocation.getFragment()); + } + + if (!redirectSet.add(redirectLocation) || redirectSet.size() > maxRedirects) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.redirectThreshold", redirectLocation, + Integer.toString(redirectSet.size()), + Integer.toString(maxRedirects))); + } + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, redirectLocation, redirectSet); + + } + + else if (httpResponse.status == 401) { + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.failedAuthentication", + Integer.valueOf(httpResponse.status))); + } + + List<String> wwwAuthenticateHeaders = httpResponse.getHandshakeResponse() + .getHeaders().get(Constants.WWW_AUTHENTICATE_HEADER_NAME); + + if (wwwAuthenticateHeaders == null || wwwAuthenticateHeaders.isEmpty() || + wwwAuthenticateHeaders.get(0) == null || wwwAuthenticateHeaders.get(0).isEmpty()) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.missingWWWAuthenticateHeader", + Integer.toString(httpResponse.status))); + } + + String authScheme = wwwAuthenticateHeaders.get(0).split("\\s+", 2)[0]; + String requestUri = new String(request.array(), StandardCharsets.ISO_8859_1) + .split("\\s", 3)[1]; + + Authenticator auth = AuthenticatorFactory.getAuthenticator(authScheme); + + if (auth == null) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.unsupportedAuthScheme", + Integer.valueOf(httpResponse.status), authScheme)); + } + + userProperties.put(Constants.AUTHORIZATION_HEADER_NAME, auth.getAuthorization( + requestUri, wwwAuthenticateHeaders.get(0), userProperties)); + + return connectToServerRecursive(endpoint, clientEndpointConfiguration, path, redirectSet); + + } + + else { + throw new DeploymentException(sm.getString("wsWebSocketContainer.invalidStatus", + Integer.toString(httpResponse.status))); + } + } + HandshakeResponse handshakeResponse = httpResponse.getHandshakeResponse(); + clientEndpointConfiguration.getConfigurator().afterResponse(handshakeResponse); + + // Sub-protocol + List<String> protocolHeaders = handshakeResponse.getHeaders().get( + Constants.WS_PROTOCOL_HEADER_NAME); + if (protocolHeaders == null || protocolHeaders.size() == 0) { + subProtocol = null; + } else if (protocolHeaders.size() == 1) { + subProtocol = protocolHeaders.get(0); + } else { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.invalidSubProtocol")); + } + + // Extensions + // Should normally only be one header but handle the case of + // multiple headers + List<String> extHeaders = handshakeResponse.getHeaders().get( + Constants.WS_EXTENSIONS_HEADER_NAME); + if (extHeaders != null) { + for (String extHeader : extHeaders) { + Util.parseExtensionHeader(extensionsAgreed, extHeader); + } + } + + // Build the transformations + TransformationFactory factory = TransformationFactory.getInstance(); + for (Extension extension : extensionsAgreed) { + List<List<Extension.Parameter>> wrapper = new ArrayList<>(1); + wrapper.add(extension.getParameters()); + Transformation t = factory.create(extension.getName(), wrapper, false); + if (t == null) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidExtensionParameters")); + } + if (transformation == null) { + transformation = t; + } else { + transformation.setNext(t); + } + } + + success = true; + } catch (ExecutionException | InterruptedException | SSLException | + EOFException | TimeoutException | URISyntaxException | AuthenticationException e) { + throw new DeploymentException( + sm.getString("wsWebSocketContainer.httpRequestFailed"), e); + } finally { + if (!success) { + channel.close(); + } + } + + // Switch to WebSocket + WsRemoteEndpointImplClient wsRemoteEndpointClient = new WsRemoteEndpointImplClient(channel); + + WsSession wsSession = new WsSession(endpoint, wsRemoteEndpointClient, + this, null, null, null, null, null, extensionsAgreed, + subProtocol, Collections.<String,String>emptyMap(), secure, + clientEndpointConfiguration, null); + + WsFrameClient wsFrameClient = new WsFrameClient(response, channel, + wsSession, transformation); + // WsFrame adds the necessary final transformations. Copy the + // completed transformation chain to the remote end point. + wsRemoteEndpointClient.setTransformation(wsFrameClient.getTransformation()); + + endpoint.onOpen(wsSession, clientEndpointConfiguration); + registerSession(endpoint, wsSession); + + /* It is possible that the server sent one or more messages as soon as + * the WebSocket connection was established. Depending on the exact + * timing of when those messages were sent they could be sat in the + * input buffer waiting to be read and will not trigger a "data + * available to read" event. Therefore, it is necessary to process the + * input buffer here. Note that this happens on the current thread which + * means that this thread will be used for any onMessage notifications. + * This is a special case. Subsequent "data available to read" events + * will be handled by threads from the AsyncChannelGroup's executor. + */ + wsFrameClient.startInputProcessing(); + + return wsSession; + } + + + private static void writeRequest(AsyncChannelWrapper channel, ByteBuffer request, + long timeout) throws TimeoutException, InterruptedException, ExecutionException { + int toWrite = request.limit(); + + Future<Integer> fWrite = channel.write(request); + Integer thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + + while (toWrite > 0) { + fWrite = channel.write(request); + thisWrite = fWrite.get(timeout, TimeUnit.MILLISECONDS); + toWrite -= thisWrite.intValue(); + } + } + + + private static boolean isRedirectStatus(int httpResponseCode) { + + boolean isRedirect = false; + + switch (httpResponseCode) { + case Constants.MULTIPLE_CHOICES: + case Constants.MOVED_PERMANENTLY: + case Constants.FOUND: + case Constants.SEE_OTHER: + case Constants.USE_PROXY: + case Constants.TEMPORARY_REDIRECT: + isRedirect = true; + break; + default: + break; + } + + return isRedirect; + } + + + private static ByteBuffer createProxyRequest(String host, int port) { + StringBuilder request = new StringBuilder(); + request.append("CONNECT "); + request.append(host); + request.append(':'); + request.append(port); + + request.append(" HTTP/1.1\r\nProxy-Connection: keep-alive\r\nConnection: keepalive\r\nHost: "); + request.append(host); + request.append(':'); + request.append(port); + + request.append("\r\n\r\n"); + + byte[] bytes = request.toString().getBytes(StandardCharsets.ISO_8859_1); + return ByteBuffer.wrap(bytes); + } + + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + + if (!wsSession.isOpen()) { + // The session was closed during onOpen. No need to register it. + return; + } + synchronized (endPointSessionMapLock) { + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().register(this); + } + Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions == null) { + wsSessions = new HashSet<>(); + endpointSessionMap.put(endpoint, wsSessions); + } + wsSessions.add(wsSession); + } + sessions.put(wsSession, wsSession); + } + + + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + + synchronized (endPointSessionMapLock) { + Set<WsSession> wsSessions = endpointSessionMap.get(endpoint); + if (wsSessions != null) { + wsSessions.remove(wsSession); + if (wsSessions.size() == 0) { + endpointSessionMap.remove(endpoint); + } + } + if (endpointSessionMap.size() == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + sessions.remove(wsSession); + } + + + Set<Session> getOpenSessions(Endpoint endpoint) { + HashSet<Session> result = new HashSet<>(); + synchronized (endPointSessionMapLock) { + Set<WsSession> sessions = endpointSessionMap.get(endpoint); + if (sessions != null) { + result.addAll(sessions); + } + } + return result; + } + + private static Map<String, List<String>> createRequestHeaders(String host, int port, + ClientEndpointConfig clientEndpointConfiguration) { + + Map<String, List<String>> headers = new HashMap<>(); + List<Extension> extensions = clientEndpointConfiguration.getExtensions(); + List<String> subProtocols = clientEndpointConfiguration.getPreferredSubprotocols(); + Map<String, Object> userProperties = clientEndpointConfiguration.getUserProperties(); + + if (userProperties.get(Constants.AUTHORIZATION_HEADER_NAME) != null) { + List<String> authValues = new ArrayList<>(1); + authValues.add((String) userProperties.get(Constants.AUTHORIZATION_HEADER_NAME)); + headers.put(Constants.AUTHORIZATION_HEADER_NAME, authValues); + } + + // Host header + List<String> hostValues = new ArrayList<>(1); + if (port == -1) { + hostValues.add(host); + } else { + hostValues.add(host + ':' + port); + } + + headers.put(Constants.HOST_HEADER_NAME, hostValues); + + // Upgrade header + List<String> upgradeValues = new ArrayList<>(1); + upgradeValues.add(Constants.UPGRADE_HEADER_VALUE); + headers.put(Constants.UPGRADE_HEADER_NAME, upgradeValues); + + // Connection header + List<String> connectionValues = new ArrayList<>(1); + connectionValues.add(Constants.CONNECTION_HEADER_VALUE); + headers.put(Constants.CONNECTION_HEADER_NAME, connectionValues); + + // WebSocket version header + List<String> wsVersionValues = new ArrayList<>(1); + wsVersionValues.add(Constants.WS_VERSION_HEADER_VALUE); + headers.put(Constants.WS_VERSION_HEADER_NAME, wsVersionValues); + + // WebSocket key + List<String> wsKeyValues = new ArrayList<>(1); + wsKeyValues.add(generateWsKeyValue()); + headers.put(Constants.WS_KEY_HEADER_NAME, wsKeyValues); + + // WebSocket sub-protocols + if (subProtocols != null && subProtocols.size() > 0) { + headers.put(Constants.WS_PROTOCOL_HEADER_NAME, subProtocols); + } + + // WebSocket extensions + if (extensions != null && extensions.size() > 0) { + headers.put(Constants.WS_EXTENSIONS_HEADER_NAME, + generateExtensionHeaders(extensions)); + } + + return headers; + } + + + private static List<String> generateExtensionHeaders(List<Extension> extensions) { + List<String> result = new ArrayList<>(extensions.size()); + for (Extension extension : extensions) { + StringBuilder header = new StringBuilder(); + header.append(extension.getName()); + for (Extension.Parameter param : extension.getParameters()) { + header.append(';'); + header.append(param.getName()); + String value = param.getValue(); + if (value != null && value.length() > 0) { + header.append('='); + header.append(value); + } + } + result.add(header.toString()); + } + return result; + } + + + private static String generateWsKeyValue() { + byte[] keyBytes = new byte[16]; + RANDOM.nextBytes(keyBytes); + return Base64.encodeBase64String(keyBytes); + } + + + private static ByteBuffer createRequest(URI uri, Map<String,List<String>> reqHeaders) { + ByteBuffer result = ByteBuffer.allocate(4 * 1024); + + // Request line + result.put(GET_BYTES); + if (null == uri.getPath() || "".equals(uri.getPath())) { + result.put(ROOT_URI_BYTES); + } else { + result.put(uri.getRawPath().getBytes(StandardCharsets.ISO_8859_1)); + } + String query = uri.getRawQuery(); + if (query != null) { + result.put((byte) '?'); + result.put(query.getBytes(StandardCharsets.ISO_8859_1)); + } + result.put(HTTP_VERSION_BYTES); + + // Headers + for (Entry<String, List<String>> entry : reqHeaders.entrySet()) { + result = addHeader(result, entry.getKey(), entry.getValue()); + } + + // Terminating CRLF + result.put(CRLF); + + result.flip(); + + return result; + } + + + private static ByteBuffer addHeader(ByteBuffer result, String key, List<String> values) { + if (values.isEmpty()) { + return result; + } + + result = putWithExpand(result, key.getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, ": ".getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, StringUtils.join(values).getBytes(StandardCharsets.ISO_8859_1)); + result = putWithExpand(result, CRLF); + + return result; + } + + + private static ByteBuffer putWithExpand(ByteBuffer input, byte[] bytes) { + if (bytes.length > input.remaining()) { + int newSize; + if (bytes.length > input.capacity()) { + newSize = 2 * bytes.length; + } else { + newSize = input.capacity() * 2; + } + ByteBuffer expanded = ByteBuffer.allocate(newSize); + input.flip(); + expanded.put(input); + input = expanded; + } + return input.put(bytes); + } + + + /** + * Process response, blocking until HTTP response has been fully received. + * @throws ExecutionException + * @throws InterruptedException + * @throws DeploymentException + * @throws TimeoutException + */ + private HttpResponse processResponse(ByteBuffer response, + AsyncChannelWrapper channel, long timeout) throws InterruptedException, + ExecutionException, DeploymentException, EOFException, + TimeoutException { + + Map<String,List<String>> headers = new CaseInsensitiveKeyMap<>(); + + int status = 0; + boolean readStatus = false; + boolean readHeaders = false; + String line = null; + while (!readHeaders) { + // On entering loop buffer will be empty and at the start of a new + // loop the buffer will have been fully read. + response.clear(); + // Blocking read + Future<Integer> read = channel.read(response); + Integer bytesRead = read.get(timeout, TimeUnit.MILLISECONDS); + if (bytesRead.intValue() == -1) { + throw new EOFException(); + } + response.flip(); + while (response.hasRemaining() && !readHeaders) { + if (line == null) { + line = readLine(response); + } else { + line += readLine(response); + } + if ("\r\n".equals(line)) { + readHeaders = true; + } else if (line.endsWith("\r\n")) { + if (readStatus) { + parseHeaders(line, headers); + } else { + status = parseStatus(line); + readStatus = true; + } + line = null; + } + } + } + + return new HttpResponse(status, new WsHandshakeResponse(headers)); + } + + + private int parseStatus(String line) throws DeploymentException { + // This client only understands HTTP 1. + // RFC2616 is case specific + String[] parts = line.trim().split(" "); + // CONNECT for proxy may return a 1.0 response + if (parts.length < 2 || !("HTTP/1.0".equals(parts[0]) || "HTTP/1.1".equals(parts[0]))) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + try { + return Integer.parseInt(parts[1]); + } catch (NumberFormatException nfe) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.invalidStatus", line)); + } + } + + + private void parseHeaders(String line, Map<String,List<String>> headers) { + // Treat headers as single values by default. + + int index = line.indexOf(':'); + if (index == -1) { + log.warn(sm.getString("wsWebSocketContainer.invalidHeader", line)); + return; + } + // Header names are case insensitive so always use lower case + String headerName = line.substring(0, index).trim().toLowerCase(Locale.ENGLISH); + // Multi-value headers are stored as a single header and the client is + // expected to handle splitting into individual values + String headerValue = line.substring(index + 1).trim(); + + List<String> values = headers.get(headerName); + if (values == null) { + values = new ArrayList<>(1); + headers.put(headerName, values); + } + values.add(headerValue); + } + + private String readLine(ByteBuffer response) { + // All ISO-8859-1 + StringBuilder sb = new StringBuilder(); + + char c = 0; + while (response.hasRemaining()) { + c = (char) response.get(); + sb.append(c); + if (c == 10) { + break; + } + } + + return sb.toString(); + } + + + private SSLEngine createSSLEngine(Map<String,Object> userProperties, String host, int port) + throws DeploymentException { + + try { + // See if a custom SSLContext has been provided + SSLContext sslContext = + (SSLContext) userProperties.get(Constants.SSL_CONTEXT_PROPERTY); + + if (sslContext == null) { + // Create the SSL Context + sslContext = SSLContext.getInstance("TLS"); + + // Trust store + String sslTrustStoreValue = + (String) userProperties.get(Constants.SSL_TRUSTSTORE_PROPERTY); + if (sslTrustStoreValue != null) { + String sslTrustStorePwdValue = (String) userProperties.get( + Constants.SSL_TRUSTSTORE_PWD_PROPERTY); + if (sslTrustStorePwdValue == null) { + sslTrustStorePwdValue = Constants.SSL_TRUSTSTORE_PWD_DEFAULT; + } + + File keyStoreFile = new File(sslTrustStoreValue); + KeyStore ks = KeyStore.getInstance("JKS"); + try (InputStream is = new FileInputStream(keyStoreFile)) { + ks.load(is, sslTrustStorePwdValue.toCharArray()); + } + + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init(ks); + + sslContext.init(null, tmf.getTrustManagers(), null); + } else { + sslContext.init(null, null, null); + } + } + + SSLEngine engine = sslContext.createSSLEngine(host, port); + + String sslProtocolsValue = + (String) userProperties.get(Constants.SSL_PROTOCOLS_PROPERTY); + if (sslProtocolsValue != null) { + engine.setEnabledProtocols(sslProtocolsValue.split(",")); + } + + engine.setUseClientMode(true); + + // Enable host verification + // Start with current settings (returns a copy) + SSLParameters sslParams = engine.getSSLParameters(); + // Use HTTPS since WebSocket starts over HTTP(S) + sslParams.setEndpointIdentificationAlgorithm("HTTPS"); + // Write the parameters back + engine.setSSLParameters(sslParams); + + return engine; + } catch (Exception e) { + throw new DeploymentException(sm.getString( + "wsWebSocketContainer.sslEngineFail"), e); + } + } + + + @Override + public long getDefaultMaxSessionIdleTimeout() { + return defaultMaxSessionIdleTimeout; + } + + + @Override + public void setDefaultMaxSessionIdleTimeout(long timeout) { + this.defaultMaxSessionIdleTimeout = timeout; + } + + + @Override + public int getDefaultMaxBinaryMessageBufferSize() { + return maxBinaryMessageBufferSize; + } + + + @Override + public void setDefaultMaxBinaryMessageBufferSize(int max) { + maxBinaryMessageBufferSize = max; + } + + + @Override + public int getDefaultMaxTextMessageBufferSize() { + return maxTextMessageBufferSize; + } + + + @Override + public void setDefaultMaxTextMessageBufferSize(int max) { + maxTextMessageBufferSize = max; + } + + + /** + * {@inheritDoc} + * + * Currently, this implementation does not support any extensions. + */ + @Override + public Set<Extension> getInstalledExtensions() { + return Collections.emptySet(); + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public long getDefaultAsyncSendTimeout() { + return defaultAsyncTimeout; + } + + + /** + * {@inheritDoc} + * + * The default value for this implementation is -1. + */ + @Override + public void setAsyncSendTimeout(long timeout) { + this.defaultAsyncTimeout = timeout; + } + + + /** + * Cleans up the resources still in use by WebSocket sessions created from + * this container. This includes closing sessions and cancelling + * {@link Future}s associated with blocking read/writes. + */ + public void destroy() { + CloseReason cr = new CloseReason( + CloseCodes.GOING_AWAY, sm.getString("wsWebSocketContainer.shutdown")); + + for (WsSession session : sessions.keySet()) { + try { + session.close(cr); + } catch (IOException ioe) { + log.debug(sm.getString( + "wsWebSocketContainer.sessionCloseFail", session.getId()), ioe); + } + } + + // Only unregister with AsyncChannelGroupUtil if this instance + // registered with it + if (asynchronousChannelGroup != null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup != null) { + AsyncChannelGroupUtil.unregister(); + asynchronousChannelGroup = null; + } + } + } + } + + + private AsynchronousChannelGroup getAsynchronousChannelGroup() { + // Use AsyncChannelGroupUtil to share a common group amongst all + // WebSocket clients + AsynchronousChannelGroup result = asynchronousChannelGroup; + if (result == null) { + synchronized (asynchronousChannelGroupLock) { + if (asynchronousChannelGroup == null) { + asynchronousChannelGroup = AsyncChannelGroupUtil.register(); + } + result = asynchronousChannelGroup; + } + } + return result; + } + + + // ----------------------------------------------- BackgroundProcess methods + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + for (WsSession wsSession : sessions.keySet()) { + wsSession.checkExpiration(); + } + } + + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 10 which means session expirations are processed + * every 10 seconds. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + private static class HttpResponse { + private final int status; + private final HandshakeResponse handshakeResponse; + + public HttpResponse(int status, HandshakeResponse handshakeResponse) { + this.status = status; + this.handshakeResponse = handshakeResponse; + } + + + public int getStatus() { + return status; + } + + + public HandshakeResponse getHandshakeResponse() { + return handshakeResponse; + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/Constants.java b/src/java/nginx/unit/websocket/pojo/Constants.java new file mode 100644 index 00000000..93cdecc7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/Constants.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String POJO_PATH_PARAM_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.pathParams"; + public static final String POJO_METHOD_MAPPING_KEY = + "nginx.unit.websocket.pojo.PojoEndpoint.methodMapping"; + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/pojo/LocalStrings.properties b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties new file mode 100644 index 00000000..00ab7e6b --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/LocalStrings.properties @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +pojoEndpointBase.closeSessionFail=Failed to close WebSocket session during error handling +pojoEndpointBase.onCloseFail=Failed to call onClose method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onError=No error handling configured for [{0}] and the following error occurred +pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point for POJO of type [{0}] +pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for POJO of type [{0}] +pojoEndpointServer.getPojoInstanceFail=Failed to create instance of POJO of type [{0}] +pojoMethodMapping.decodePathParamFail=Failed to decode path parameter value [{0}] to expected type [{1}] +pojoMethodMapping.duplicateAnnotation=Duplicate annotations [{0}] present on class [{1}] +pojoMethodMapping.duplicateLastParam=Multiple boolean (last) parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateMessageParam=Multiple message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicatePongMessageParam=Multiple PongMessage parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.duplicateSessionParam=Multiple session parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.invalidDecoder=The specified decoder of type [{0}] could not be instantiated +pojoMethodMapping.invalidPathParamType=Parameters annotated with @PathParam may only be Strings, Java primitives or a boxed version thereof +pojoMethodMapping.methodNotPublic=The annotated method [{0}] is not public +pojoMethodMapping.noPayload=No payload parameter present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.onErrorNoThrowable=No Throwable parameter was present on the method [{0}] of class [{1}] that was annotated with OnError +pojoMethodMapping.paramWithoutAnnotation=A parameter of type [{0}] was found on method[{1}] of class [{2}] that did not have a @PathParam annotation +pojoMethodMapping.partialInputStream=Invalid InputStream and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialObject=Invalid Object and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialPong=Invalid PongMessage and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.partialReader=Invalid Reader and boolean parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMethodMapping.pongWithPayload=Invalid PongMessage and Message parameters present on the method [{0}] of class [{1}] that was annotated with OnMessage +pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message +pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for this implementation is Integer.MAX_VALUE diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java new file mode 100644 index 00000000..be679a35 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointBase.java @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.ExceptionUtils; +import org.apache.tomcat.util.res.StringManager; + +/** + * Base implementation (client and server have different concrete + * implementations) of the wrapper that converts a POJO instance into a + * WebSocket endpoint instance. + */ +public abstract class PojoEndpointBase extends Endpoint { + + private final Log log = LogFactory.getLog(PojoEndpointBase.class); // must not be static + private static final StringManager sm = StringManager.getManager(PojoEndpointBase.class); + + private Object pojo; + private Map<String,String> pathParameters; + private PojoMethodMapping methodMapping; + + + protected final void doOnOpen(Session session, EndpointConfig config) { + PojoMethodMapping methodMapping = getMethodMapping(); + Object pojo = getPojo(); + Map<String,String> pathParameters = getPathParameters(); + + // Add message handlers before calling onOpen since that may trigger a + // message which in turn could trigger a response and/or close the + // session + for (MessageHandler mh : methodMapping.getMessageHandlers(pojo, + pathParameters, session, config)) { + session.addMessageHandler(mh); + } + + if (methodMapping.getOnOpen() != null) { + try { + methodMapping.getOnOpen().invoke(pojo, + methodMapping.getOnOpenArgs( + pathParameters, session, config)); + + } catch (IllegalAccessException e) { + // Reflection related problems + log.error(sm.getString( + "pojoEndpointBase.onOpenFail", + pojo.getClass().getName()), e); + handleOnOpenOrCloseError(session, e); + } catch (InvocationTargetException e) { + Throwable cause = e.getCause(); + handleOnOpenOrCloseError(session, cause); + } catch (Throwable t) { + handleOnOpenOrCloseError(session, t); + } + } + } + + + private void handleOnOpenOrCloseError(Session session, Throwable t) { + // If really fatal - re-throw + ExceptionUtils.handleThrowable(t); + + // Trigger the error handler and close the session + onError(session, t); + try { + session.close(); + } catch (IOException ioe) { + log.warn(sm.getString("pojoEndpointBase.closeSessionFail"), ioe); + } + } + + @Override + public final void onClose(Session session, CloseReason closeReason) { + + if (methodMapping.getOnClose() != null) { + try { + methodMapping.getOnClose().invoke(pojo, + methodMapping.getOnCloseArgs(pathParameters, session, closeReason)); + } catch (Throwable t) { + log.error(sm.getString("pojoEndpointBase.onCloseFail", + pojo.getClass().getName()), t); + handleOnOpenOrCloseError(session, t); + } + } + + // Trigger the destroy method for any associated decoders + Set<MessageHandler> messageHandlers = session.getMessageHandlers(); + for (MessageHandler messageHandler : messageHandlers) { + if (messageHandler instanceof PojoMessageHandlerWholeBase<?>) { + ((PojoMessageHandlerWholeBase<?>) messageHandler).onClose(); + } + } + } + + + @Override + public final void onError(Session session, Throwable throwable) { + + if (methodMapping.getOnError() == null) { + log.error(sm.getString("pojoEndpointBase.onError", + pojo.getClass().getName()), throwable); + } else { + try { + methodMapping.getOnError().invoke( + pojo, + methodMapping.getOnErrorArgs(pathParameters, session, + throwable)); + } catch (Throwable t) { + ExceptionUtils.handleThrowable(t); + log.error(sm.getString("pojoEndpointBase.onErrorFail", + pojo.getClass().getName()), t); + } + } + } + + protected Object getPojo() { return pojo; } + protected void setPojo(Object pojo) { this.pojo = pojo; } + + + protected Map<String,String> getPathParameters() { return pathParameters; } + protected void setPathParameters(Map<String,String> pathParameters) { + this.pathParameters = pathParameters; + } + + + protected PojoMethodMapping getMethodMapping() { return methodMapping; } + protected void setMethodMapping(PojoMethodMapping methodMapping) { + this.methodMapping = methodMapping; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java new file mode 100644 index 00000000..6e569487 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointClient.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Collections; +import java.util.List; + +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.ClientEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointClient extends PojoEndpointBase { + + public PojoEndpointClient(Object pojo, + List<Class<? extends Decoder>> decoders) throws DeploymentException { + setPojo(pojo); + setMethodMapping( + new PojoMethodMapping(pojo.getClass(), decoders, null)); + setPathParameters(Collections.<String,String>emptyMap()); + } + + @Override + public void onOpen(Session session, EndpointConfig config) { + doOnOpen(session, config); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java new file mode 100644 index 00000000..499f8274 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoEndpointServer.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.util.Map; + +import javax.websocket.EndpointConfig; +import javax.websocket.Session; +import javax.websocket.server.ServerEndpointConfig; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Wrapper class for instances of POJOs annotated with + * {@link javax.websocket.server.ServerEndpoint} so they appear as standard + * {@link javax.websocket.Endpoint} instances. + */ +public class PojoEndpointServer extends PojoEndpointBase { + + private static final StringManager sm = + StringManager.getManager(PojoEndpointServer.class); + + @Override + public void onOpen(Session session, EndpointConfig endpointConfig) { + + ServerEndpointConfig sec = (ServerEndpointConfig) endpointConfig; + + Object pojo; + try { + pojo = sec.getConfigurator().getEndpointInstance( + sec.getEndpointClass()); + } catch (InstantiationException e) { + throw new IllegalArgumentException(sm.getString( + "pojoEndpointServer.getPojoInstanceFail", + sec.getEndpointClass().getName()), e); + } + setPojo(pojo); + + @SuppressWarnings("unchecked") + Map<String,String> pathParameters = + (Map<String, String>) sec.getUserProperties().get( + Constants.POJO_PATH_PARAM_KEY); + setPathParameters(pathParameters); + + PojoMethodMapping methodMapping = + (PojoMethodMapping) sec.getUserProperties().get( + Constants.POJO_METHOD_MAPPING_KEY); + setMethodMapping(methodMapping); + + doOnOpen(session, endpointConfig); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java new file mode 100644 index 00000000..b72d719a --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerBase.java @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.EncodeException; +import javax.websocket.MessageHandler; +import javax.websocket.RemoteEndpoint; +import javax.websocket.Session; + +import org.apache.tomcat.util.ExceptionUtils; +import nginx.unit.websocket.WrappedMessageHandler; + +/** + * Common implementation code for the POJO message handlers. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerBase<T> + implements WrappedMessageHandler { + + protected final Object pojo; + protected final Method method; + protected final Session session; + protected final Object[] params; + protected final int indexPayload; + protected final boolean convert; + protected final int indexSession; + protected final long maxMessageSize; + + public PojoMessageHandlerBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession, long maxMessageSize) { + this.pojo = pojo; + this.method = method; + // TODO: The method should already be accessible here but the following + // code seems to be necessary in some as yet not fully understood cases. + try { + this.method.setAccessible(true); + } catch (Exception e) { + // It is better to make sure the method is accessible, but + // ignore exceptions and hope for the best + } + this.session = session; + this.params = params; + this.indexPayload = indexPayload; + this.convert = convert; + this.indexSession = indexSession; + this.maxMessageSize = maxMessageSize; + } + + + protected final void processResult(Object result) { + if (result == null) { + return; + } + + RemoteEndpoint.Basic remoteEndpoint = session.getBasicRemote(); + try { + if (result instanceof String) { + remoteEndpoint.sendText((String) result); + } else if (result instanceof ByteBuffer) { + remoteEndpoint.sendBinary((ByteBuffer) result); + } else if (result instanceof byte[]) { + remoteEndpoint.sendBinary(ByteBuffer.wrap((byte[]) result)); + } else { + remoteEndpoint.sendObject(result); + } + } catch (IOException | EncodeException ioe) { + throw new IllegalStateException(ioe); + } + } + + + /** + * Expose the POJO if it is a message handler so the Session is able to + * match requests to remove handlers if the original handler has been + * wrapped. + */ + @Override + public final MessageHandler getWrappedHandler() { + if (pojo instanceof MessageHandler) { + return (MessageHandler) pojo; + } else { + return null; + } + } + + + @Override + public final long getMaxMessageSize() { + return maxMessageSize; + } + + + protected final void handlePojoMethodException(Throwable t) { + t = ExceptionUtils.unwrapInvocationTargetException(t); + ExceptionUtils.handleThrowable(t); + if (t instanceof RuntimeException) { + throw (RuntimeException) t; + } else { + throw new RuntimeException(t.getMessage(), t); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java new file mode 100644 index 00000000..d6f37724 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBase.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO partial message handlers. All + * the real work is done in this class and in the superclass. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerPartialBase<T> + extends PojoMessageHandlerBase<T> implements MessageHandler.Partial<T> { + + private final int indexBoolean; + + public PojoMessageHandlerPartialBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexBoolean, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + this.indexBoolean = indexBoolean; + } + + + @Override + public final void onMessage(T message, boolean last) { + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + Object[] parameters = params.clone(); + if (indexBoolean != -1) { + parameters[indexBoolean] = Boolean.valueOf(last); + } + if (indexSession != -1) { + parameters[indexSession] = session; + } + if (convert) { + parameters[indexPayload] = ((ByteBuffer) message).array(); + } else { + parameters[indexPayload] = message; + } + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java new file mode 100644 index 00000000..1d334017 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialBinary.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +import javax.websocket.Session; + +/** + * ByteBuffer specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialBinary + extends PojoMessageHandlerPartialBase<ByteBuffer> { + + public PojoMessageHandlerPartialBinary(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java new file mode 100644 index 00000000..8f7c1a0d --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerPartialText.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.Session; + +/** + * Text specific concrete implementation for handling partial messages. + */ +public class PojoMessageHandlerPartialText + extends PojoMessageHandlerPartialBase<String> { + + public PojoMessageHandlerPartialText(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexBoolean, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, indexBoolean, + indexSession, maxMessageSize); + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java new file mode 100644 index 00000000..23333eb7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBase.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +import javax.websocket.DecodeException; +import javax.websocket.MessageHandler; +import javax.websocket.Session; + +import nginx.unit.websocket.WsSession; + +/** + * Common implementation code for the POJO whole message handlers. All the real + * work is done in this class and in the superclass. + * + * @param <T> The type of message to handle + */ +public abstract class PojoMessageHandlerWholeBase<T> + extends PojoMessageHandlerBase<T> implements MessageHandler.Whole<T> { + + public PojoMessageHandlerWholeBase(Object pojo, Method method, + Session session, Object[] params, int indexPayload, + boolean convert, int indexSession, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + } + + + @Override + public final void onMessage(T message) { + + if (params.length == 1 && params[0] instanceof DecodeException) { + ((WsSession) session).getLocal().onError(session, + (DecodeException) params[0]); + return; + } + + // Can this message be decoded? + Object payload; + try { + payload = decode(message); + } catch (DecodeException de) { + ((WsSession) session).getLocal().onError(session, de); + return; + } + + if (payload == null) { + // Not decoded. Convert if required. + if (convert) { + payload = convert(message); + } else { + payload = message; + } + } + + Object[] parameters = params.clone(); + if (indexSession != -1) { + parameters[indexSession] = session; + } + parameters[indexPayload] = payload; + + Object result = null; + try { + result = method.invoke(pojo, parameters); + } catch (IllegalAccessException | InvocationTargetException e) { + handlePojoMethodException(e); + } + processResult(result); + } + + protected Object convert(T message) { + return message; + } + + + protected abstract Object decode(T message) throws DecodeException; + protected abstract void onClose(); +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java new file mode 100644 index 00000000..07ff0648 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeBinary.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; + +/** + * ByteBuffer specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeBinary + extends PojoMessageHandlerWholeBase<ByteBuffer> { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeBinary.class); + + private final List<Decoder> decoders = new ArrayList<>(); + + private final boolean isForInputStream; + + public PojoMessageHandlerWholeBinary(Object pojo, Method method, + Session session, EndpointConfig config, + List<Class<? extends Decoder>> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + boolean isForInputStream, long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update binary text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxBinaryMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxBinaryMessageBufferSize((int) maxMessageSize); + } + + try { + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + if (Binary.class.isAssignableFrom(decoderClazz)) { + Binary<?> decoder = (Binary<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (BinaryStream.class.isAssignableFrom( + decoderClazz)) { + BinaryStream<?> decoder = (BinaryStream<?>) + decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Text decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + this.isForInputStream = isForInputStream; + } + + + @Override + protected Object decode(ByteBuffer message) throws DecodeException { + for (Decoder decoder : decoders) { + if (decoder instanceof Binary) { + if (((Binary<?>) decoder).willDecode(message)) { + return ((Binary<?>) decoder).decode(message); + } + } else { + byte[] array = new byte[message.limit() - message.position()]; + message.get(array); + ByteArrayInputStream bais = new ByteArrayInputStream(array); + try { + return ((BinaryStream<?>) decoder).decode(bais); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(ByteBuffer message) { + byte[] array = new byte[message.remaining()]; + message.get(array); + if (isForInputStream) { + return new ByteArrayInputStream(array); + } else { + return array; + } + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java new file mode 100644 index 00000000..bdedd7de --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholePong.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.lang.reflect.Method; + +import javax.websocket.PongMessage; +import javax.websocket.Session; + +/** + * PongMessage specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholePong + extends PojoMessageHandlerWholeBase<PongMessage> { + + public PojoMessageHandlerWholePong(Object pojo, Method method, + Session session, Object[] params, int indexPayload, boolean convert, + int indexSession) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, -1); + } + + @Override + protected Object decode(PongMessage message) { + // Never decoded + return null; + } + + + @Override + protected void onClose() { + // NO-OP + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java new file mode 100644 index 00000000..59007349 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMessageHandlerWholeText.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.IOException; +import java.io.StringReader; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; + +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.EndpointConfig; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Util; + + +/** + * Text specific concrete implementation for handling whole messages. + */ +public class PojoMessageHandlerWholeText + extends PojoMessageHandlerWholeBase<String> { + + private static final StringManager sm = + StringManager.getManager(PojoMessageHandlerWholeText.class); + + private final List<Decoder> decoders = new ArrayList<>(); + private final Class<?> primitiveType; + + public PojoMessageHandlerWholeText(Object pojo, Method method, + Session session, EndpointConfig config, + List<Class<? extends Decoder>> decoderClazzes, Object[] params, + int indexPayload, boolean convert, int indexSession, + long maxMessageSize) { + super(pojo, method, session, params, indexPayload, convert, + indexSession, maxMessageSize); + + // Update max text size handled by session + if (maxMessageSize > -1 && maxMessageSize > session.getMaxTextMessageBufferSize()) { + if (maxMessageSize > Integer.MAX_VALUE) { + throw new IllegalArgumentException(sm.getString( + "pojoMessageHandlerWhole.maxBufferSize")); + } + session.setMaxTextMessageBufferSize((int) maxMessageSize); + } + + // Check for primitives + Class<?> type = method.getParameterTypes()[indexPayload]; + if (Util.isPrimitive(type)) { + primitiveType = type; + return; + } else { + primitiveType = null; + } + + try { + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + if (Text.class.isAssignableFrom(decoderClazz)) { + Text<?> decoder = (Text<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else if (TextStream.class.isAssignableFrom( + decoderClazz)) { + TextStream<?> decoder = + (TextStream<?>) decoderClazz.getConstructor().newInstance(); + decoder.init(config); + decoders.add(decoder); + } else { + // Binary decoder - ignore it + } + } + } + } catch (ReflectiveOperationException e) { + throw new IllegalArgumentException(e); + } + } + + + @Override + protected Object decode(String message) throws DecodeException { + // Handle primitives + if (primitiveType != null) { + return Util.coerceToType(primitiveType, message); + } + // Handle full decoders + for (Decoder decoder : decoders) { + if (decoder instanceof Text) { + if (((Text<?>) decoder).willDecode(message)) { + return ((Text<?>) decoder).decode(message); + } + } else { + StringReader r = new StringReader(message); + try { + return ((TextStream<?>) decoder).decode(r); + } catch (IOException ioe) { + throw new DecodeException(message, sm.getString( + "pojoMessageHandlerWhole.decodeIoFail"), ioe); + } + } + } + return null; + } + + + @Override + protected Object convert(String message) { + return new StringReader(message); + } + + + @Override + protected void onClose() { + for (Decoder decoder : decoders) { + decoder.destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java new file mode 100644 index 00000000..2385b5c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.PongMessage; +import javax.websocket.Session; +import javax.websocket.server.PathParam; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.DecoderEntry; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.Util.DecoderMatch; + +/** + * For a POJO class annotated with + * {@link javax.websocket.server.ServerEndpoint}, an instance of this class + * creates and caches the method handler, method information and parameter + * information for the onXXX calls. + */ +public class PojoMethodMapping { + + private static final StringManager sm = + StringManager.getManager(PojoMethodMapping.class); + + private final Method onOpen; + private final Method onClose; + private final Method onError; + private final PojoPathParam[] onOpenParams; + private final PojoPathParam[] onCloseParams; + private final PojoPathParam[] onErrorParams; + private final List<MessageHandlerInfo> onMessage = new ArrayList<>(); + private final String wsPath; + + + public PojoMethodMapping(Class<?> clazzPojo, + List<Class<? extends Decoder>> decoderClazzes, String wsPath) + throws DeploymentException { + + this.wsPath = wsPath; + + List<DecoderEntry> decoders = Util.getDecoders(decoderClazzes); + Method open = null; + Method close = null; + Method error = null; + Method[] clazzPojoMethods = null; + Class<?> currentClazz = clazzPojo; + while (!currentClazz.equals(Object.class)) { + Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); + if (currentClazz == clazzPojo) { + clazzPojoMethods = currentClazzMethods; + } + for (Method method : currentClazzMethods) { + if (method.getAnnotation(OnOpen.class) != null) { + checkPublic(method); + if (open == null) { + open = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(open, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnOpen.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnClose.class) != null) { + checkPublic(method); + if (close == null) { + close = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(close, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnClose.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnError.class) != null) { + checkPublic(method); + if (error == null) { + error = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(error, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnError.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnMessage.class) != null) { + checkPublic(method); + MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders); + boolean found = false; + for (MessageHandlerInfo otherMessageHandler : onMessage) { + if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) { + found = true; + if (currentClazz == clazzPojo || + !isMethodOverride(messageHandler.m, otherMessageHandler.m)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnMessage.class, currentClazz)); + } + } + } + if (!found) { + onMessage.add(messageHandler); + } + } else { + // Method not annotated + } + } + currentClazz = currentClazz.getSuperclass(); + } + // If the methods are not on clazzPojo and they are overridden + // by a non annotated method in clazzPojo, they should be ignored + if (open != null && open.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) { + open = null; + } + } + if (close != null && close.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) { + close = null; + } + } + if (error != null && error.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) { + error = null; + } + } + List<MessageHandlerInfo> overriddenOnMessage = new ArrayList<>(); + for (MessageHandlerInfo messageHandler : onMessage) { + if (messageHandler.m.getDeclaringClass() != clazzPojo + && isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) { + overriddenOnMessage.add(messageHandler); + } + } + for (MessageHandlerInfo messageHandler : overriddenOnMessage) { + onMessage.remove(messageHandler); + } + this.onOpen = open; + this.onClose = close; + this.onError = error; + onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN); + onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE); + onErrorParams = getPathParams(onError, MethodType.ON_ERROR); + } + + + private void checkPublic(Method m) throws DeploymentException { + if (!Modifier.isPublic(m.getModifiers())) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.methodNotPublic", m.getName())); + } + } + + + private boolean isMethodOverride(Method method1, Method method2) { + return method1.getName().equals(method2.getName()) + && method1.getReturnType().equals(method2.getReturnType()) + && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes()); + } + + + private boolean isOverridenWithoutAnnotation(Method[] methods, + Method superclazzMethod, Class<? extends Annotation> annotation) { + for (Method method : methods) { + if (isMethodOverride(method, superclazzMethod) + && (method.getAnnotation(annotation) == null)) { + return true; + } + } + return false; + } + + + public String getWsPath() { + return wsPath; + } + + + public Method getOnOpen() { + return onOpen; + } + + + public Object[] getOnOpenArgs(Map<String,String> pathParameters, + Session session, EndpointConfig config) throws DecodeException { + return buildArgs(onOpenParams, pathParameters, session, config, null, + null); + } + + + public Method getOnClose() { + return onClose; + } + + + public Object[] getOnCloseArgs(Map<String,String> pathParameters, + Session session, CloseReason closeReason) throws DecodeException { + return buildArgs(onCloseParams, pathParameters, session, null, null, + closeReason); + } + + + public Method getOnError() { + return onError; + } + + + public Object[] getOnErrorArgs(Map<String,String> pathParameters, + Session session, Throwable throwable) throws DecodeException { + return buildArgs(onErrorParams, pathParameters, session, null, + throwable, null); + } + + + public boolean hasMessageHandlers() { + return !onMessage.isEmpty(); + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Set<MessageHandler> result = new HashSet<>(); + for (MessageHandlerInfo messageMethod : onMessage) { + result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters, + session, config)); + } + return result; + } + + + private static PojoPathParam[] getPathParams(Method m, + MethodType methodType) throws DeploymentException { + if (m == null) { + return new PojoPathParam[0]; + } + boolean foundThrowable = false; + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + PojoPathParam[] result = new PojoPathParam[types.length]; + for (int i = 0; i < types.length; i++) { + Class<?> type = types[i]; + if (type.equals(Session.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_OPEN && + type.equals(EndpointConfig.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_ERROR + && type.equals(Throwable.class)) { + foundThrowable = true; + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_CLOSE && + type.equals(CloseReason.class)) { + result[i] = new PojoPathParam(type, null); + } else { + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + // Check that the type is valid. "0" coerces to every + // valid type + try { + Util.coerceToType(type, "0"); + } catch (IllegalArgumentException iae) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.invalidPathParamType"), + iae); + } + result[i] = new PojoPathParam(type, + ((PathParam) paramAnnotation).value()); + break; + } + } + // Parameters without annotations are not permitted + if (result[i] == null) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.paramWithoutAnnotation", + type, m.getName(), m.getClass().getName())); + } + } + } + if (methodType == MethodType.ON_ERROR && !foundThrowable) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.onErrorNoThrowable", + m.getName(), m.getDeclaringClass().getName())); + } + return result; + } + + + private static Object[] buildArgs(PojoPathParam[] pathParams, + Map<String,String> pathParameters, Session session, + EndpointConfig config, Throwable throwable, CloseReason closeReason) + throws DecodeException { + Object[] result = new Object[pathParams.length]; + for (int i = 0; i < pathParams.length; i++) { + Class<?> type = pathParams[i].getType(); + if (type.equals(Session.class)) { + result[i] = session; + } else if (type.equals(EndpointConfig.class)) { + result[i] = config; + } else if (type.equals(Throwable.class)) { + result[i] = throwable; + } else if (type.equals(CloseReason.class)) { + result[i] = closeReason; + } else { + String name = pathParams[i].getName(); + String value = pathParameters.get(name); + try { + result[i] = Util.coerceToType(type, value); + } catch (Exception e) { + throw new DecodeException(value, sm.getString( + "pojoMethodMapping.decodePathParamFail", + value, type), e); + } + } + } + return result; + } + + + private static class MessageHandlerInfo { + + private final Method m; + private int indexString = -1; + private int indexByteArray = -1; + private int indexByteBuffer = -1; + private int indexPong = -1; + private int indexBoolean = -1; + private int indexSession = -1; + private int indexInputStream = -1; + private int indexReader = -1; + private int indexPrimitive = -1; + private Class<?> primitiveType = null; + private Map<Integer,PojoPathParam> indexPathParams = new HashMap<>(); + private int indexPayload = -1; + private DecoderMatch decoderMatch = null; + private long maxMessageSize = -1; + + public MessageHandlerInfo(Method m, List<DecoderEntry> decoderEntries) { + this.m = m; + + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + + for (int i = 0; i < types.length; i++) { + boolean paramFound = false; + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + indexPathParams.put( + Integer.valueOf(i), new PojoPathParam(types[i], + ((PathParam) paramAnnotation).value())); + paramFound = true; + break; + } + } + if (paramFound) { + continue; + } + if (String.class.isAssignableFrom(types[i])) { + if (indexString == -1) { + indexString = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Reader.class.isAssignableFrom(types[i])) { + if (indexReader == -1) { + indexReader = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (boolean.class == types[i]) { + if (indexBoolean == -1) { + indexBoolean = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateLastParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (ByteBuffer.class.isAssignableFrom(types[i])) { + if (indexByteBuffer == -1) { + indexByteBuffer = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (byte[].class == types[i]) { + if (indexByteArray == -1) { + indexByteArray = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (InputStream.class.isAssignableFrom(types[i])) { + if (indexInputStream == -1) { + indexInputStream = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Util.isPrimitive(types[i])) { + if (indexPrimitive == -1) { + indexPrimitive = i; + primitiveType = types[i]; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Session.class.isAssignableFrom(types[i])) { + if (indexSession == -1) { + indexSession = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateSessionParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (PongMessage.class.isAssignableFrom(types[i])) { + if (indexPong == -1) { + indexPong = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicatePongMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else { + if (decoderMatch != null && decoderMatch.hasMatches()) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + decoderMatch = new DecoderMatch(types[i], decoderEntries); + + if (decoderMatch.hasMatches()) { + indexPayload = i; + } + } + } + + // Additional checks required + if (indexString != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexString; + } + } + if (indexReader != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexReader; + } + } + if (indexByteArray != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteArray; + } + } + if (indexByteBuffer != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteBuffer; + } + } + if (indexInputStream != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexInputStream; + } + } + if (indexPrimitive != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPrimitive; + } + } + if (indexPong != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.pongWithPayload", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPong; + } + } + if (indexPayload == -1 && indexPrimitive == -1 && + indexBoolean != -1) { + // The boolean we found is a payload, not a last flag + indexPayload = indexBoolean; + indexPrimitive = indexBoolean; + primitiveType = Boolean.TYPE; + indexBoolean = -1; + } + if (indexPayload == -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.noPayload", + m.getName(), m.getDeclaringClass().getName())); + } + if (indexPong != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialPong", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexReader != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialReader", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexInputStream != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialInputStream", + m.getName(), m.getDeclaringClass().getName())); + } + if (decoderMatch != null && decoderMatch.hasMatches() && + indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialObject", + m.getName(), m.getDeclaringClass().getName())); + } + + maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize(); + } + + + public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) { + if (otherHandler == null) { + return false; + } + if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) { + return true; + } + if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) { + return true; + } + if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) { + return true; + } + if (indexPong >= 0 && otherHandler.indexPong >= 0) { + return true; + } + if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0 + && primitiveType == otherHandler.primitiveType) { + return true; + } + if (indexReader >= 0 && otherHandler.indexReader >= 0) { + return true; + } + if (indexString >= 0 && otherHandler.indexString >= 0) { + return true; + } + if (decoderMatch != null && otherHandler.decoderMatch != null + && decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) { + return true; + } + return false; + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Object[] params = new Object[m.getParameterTypes().length]; + + for (Map.Entry<Integer,PojoPathParam> entry : + indexPathParams.entrySet()) { + PojoPathParam pathParam = entry.getValue(); + String valueString = pathParameters.get(pathParam.getName()); + Object value = null; + try { + value = Util.coerceToType(pathParam.getType(), valueString); + } catch (Exception e) { + DecodeException de = new DecodeException(valueString, + sm.getString( + "pojoMethodMapping.decodePathParamFail", + valueString, pathParam.getType()), e); + params = new Object[] { de }; + break; + } + params[entry.getKey().intValue()] = value; + } + + Set<MessageHandler> results = new HashSet<>(2); + if (indexBoolean == -1) { + // Basic + if (indexString != -1 || indexPrimitive != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexPayload, false, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexReader != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexReader, true, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteArray, + true, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexByteBuffer != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteBuffer, + false, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexInputStream != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexInputStream, + true, indexSession, true, maxMessageSize); + results.add(mh); + } else if (decoderMatch != null && decoderMatch.hasMatches()) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeBinary( + pojo, m, session, config, + decoderMatch.getBinaryDecoders(), params, + indexPayload, true, indexSession, true, + maxMessageSize); + results.add(mh); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeText( + pojo, m, session, config, + decoderMatch.getTextDecoders(), params, + indexPayload, true, indexSession, maxMessageSize); + results.add(mh); + } + } else { + MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m, + session, params, indexPong, false, indexSession); + results.add(mh); + } + } else { + // ASync + if (indexString != -1) { + MessageHandler mh = new PojoMessageHandlerPartialText(pojo, + m, session, params, indexString, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteArray, true, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteBuffer, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } + } + return results; + } + } + + + private enum MethodType { + ON_OPEN, + ON_CLOSE, + ON_ERROR + } +} diff --git a/src/java/nginx/unit/websocket/pojo/PojoPathParam.java b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java new file mode 100644 index 00000000..859b6d68 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoPathParam.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +/** + * Stores the parameter type and name for a parameter that needs to be passed to + * an onXxx method of {@link javax.websocket.Endpoint}. The name is only present + * for parameters annotated with + * {@link javax.websocket.server.PathParam}. For the + * {@link javax.websocket.Session} and {@link java.lang.Throwable} parameters, + * {@link #getName()} will always return <code>null</code>. + */ +public class PojoPathParam { + + private final Class<?> type; + private final String name; + + + public PojoPathParam(Class<?> type, String name) { + this.type = type; + this.name = name; + } + + + public Class<?> getType() { + return type; + } + + + public String getName() { + return name; + } +} diff --git a/src/java/nginx/unit/websocket/pojo/package-info.java b/src/java/nginx/unit/websocket/pojo/package-info.java new file mode 100644 index 00000000..39cf80c8 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * This package provides the necessary plumbing to convert an annotated POJO + * into a WebSocket {@link javax.websocket.Endpoint}. + */ +package nginx.unit.websocket.pojo; diff --git a/src/java/nginx/unit/websocket/server/Constants.java b/src/java/nginx/unit/websocket/server/Constants.java new file mode 100644 index 00000000..5210c4ba --- /dev/null +++ b/src/java/nginx/unit/websocket/server/Constants.java @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +/** + * Internal implementation constants. + */ +public class Constants { + + public static final String BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.binaryBufferSize"; + public static final String TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.textBufferSize"; + public static final String ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM = + "nginx.unit.websocket.noAddAfterHandshake"; + + public static final String SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE = + "javax.websocket.server.ServerContainer"; + + + private Constants() { + // Hide default constructor + } +} diff --git a/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java new file mode 100644 index 00000000..43ffe2bc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/DefaultServerEndpointConfigurator.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.HandshakeRequest; +import javax.websocket.server.ServerEndpointConfig; + +public class DefaultServerEndpointConfigurator + extends ServerEndpointConfig.Configurator { + + @Override + public <T> T getEndpointInstance(Class<T> clazz) + throws InstantiationException { + try { + return clazz.getConstructor().newInstance(); + } catch (InstantiationException e) { + throw e; + } catch (ReflectiveOperationException e) { + InstantiationException ie = new InstantiationException(); + ie.initCause(e); + throw ie; + } + } + + + @Override + public String getNegotiatedSubprotocol(List<String> supported, + List<String> requested) { + + for (String request : requested) { + if (supported.contains(request)) { + return request; + } + } + return ""; + } + + + @Override + public List<Extension> getNegotiatedExtensions(List<Extension> installed, + List<Extension> requested) { + Set<String> installedNames = new HashSet<>(); + for (Extension e : installed) { + installedNames.add(e.getName()); + } + List<Extension> result = new ArrayList<>(); + for (Extension request : requested) { + if (installedNames.contains(request.getName())) { + result.add(request); + } + } + return result; + } + + + @Override + public boolean checkOrigin(String originHeaderValue) { + return true; + } + + @Override + public void modifyHandshake(ServerEndpointConfig sec, + HandshakeRequest request, HandshakeResponse response) { + // NO-OP + } + +} diff --git a/src/java/nginx/unit/websocket/server/LocalStrings.properties b/src/java/nginx/unit/websocket/server/LocalStrings.properties new file mode 100644 index 00000000..5bc12501 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/LocalStrings.properties @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +serverContainer.addNotAllowed=No further Endpoints may be registered once an attempt has been made to use one of the previously registered endpoints +serverContainer.configuratorFail=Failed to create configurator of type [{0}] for POJO of type [{1}] +serverContainer.duplicatePaths=Multiple Endpoints may not be deployed to the same path [{0}] : existing endpoint was [{1}] and new endpoint is [{2}] +serverContainer.encoderFail=Unable to create encoder of type [{0}] +serverContainer.endpointDeploy=Endpoint class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.missingAnnotation=Cannot deploy POJO class [{0}] as it is not annotated with @ServerEndpoint +serverContainer.missingEndpoint=An Endpoint instance has been request for path [{0}] but no matching Endpoint class was found +serverContainer.pojoDeploy=POJO class [{0}] deploying to path [{1}] in ServletContext [{2}] +serverContainer.servletContextMismatch=Attempted to register a POJO annotated for WebSocket at path [{0}] in the ServletContext with context path [{1}] when the WebSocket ServerContainer is allocated to the ServletContext with context path [{2}] +serverContainer.servletContextMissing=No ServletContext was specified + +upgradeUtil.incompatibleRsv=Extensions were specified that have incompatible RSV bit usage + +uriTemplate.duplicateParameter=The parameter [{0}] appears more than once in the path which is not permitted +uriTemplate.emptySegment=The path [{0}] contains one or more empty segments which are is not permitted +uriTemplate.invalidPath=The path [{0}] is not valid. +uriTemplate.invalidSegment=The segment [{0}] is not valid in the provided path [{1}] + +wsFrameServer.bytesRead=Read [{0}] bytes into input buffer ready for processing +wsFrameServer.illegalReadState=Unexpected read state [{0}] +wsFrameServer.onDataAvailable=Method entry + +wsHttpUpgradeHandler.closeOnError=Closing WebSocket connection due to an error +wsHttpUpgradeHandler.destroyFailed=Failed to close WebConnection while destroying the WebSocket HttpUpgradeHandler +wsHttpUpgradeHandler.noPreInit=The preInit() method must be called to configure the WebSocket HttpUpgradeHandler before the container calls init(). Usually, this means the Servlet that created the WsHttpUpgradeHandler instance should also call preInit() +wsHttpUpgradeHandler.serverStop=The server is stopping + +wsRemoteEndpointServer.closeFailed=Failed to close the ServletOutputStream connection cleanly diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java new file mode 100644 index 00000000..162f01c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.ServerEndpointConfig; + +import nginx.unit.Request; + +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.security.ConcurrentMessageDigest; +import nginx.unit.websocket.Constants; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.TransformationFactory; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.WsHandshakeResponse; +import nginx.unit.websocket.pojo.PojoEndpointServer; + +public class UpgradeUtil { + + private static final StringManager sm = + StringManager.getManager(UpgradeUtil.class.getPackage().getName()); + private static final byte[] WS_ACCEPT = + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( + StandardCharsets.ISO_8859_1); + + private UpgradeUtil() { + // Utility class. Hide default constructor. + } + + /** + * Checks to see if this is an HTTP request that includes a valid upgrade + * request to web socket. + * <p> + * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java + * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC + * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. + * @param request The request to check if it is an HTTP upgrade request for + * a WebSocket connection + * @param response The response associated with the request + * @return <code>true</code> if the request includes a HTTP Upgrade request + * for the WebSocket protocol, otherwise <code>false</code> + */ + public static boolean isWebSocketUpgradeRequest(ServletRequest request, + ServletResponse response) { + + Request r = (Request) request.getAttribute(Request.BARE); + + return ((request instanceof HttpServletRequest) && + (response instanceof HttpServletResponse) && + (r != null) && + (r.isUpgrade())); + } + + + public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, + HttpServletResponse resp, ServerEndpointConfig sec, + Map<String,String> pathParams) + throws ServletException, IOException { + + + // Origin check + String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); + + if (!sec.getConfigurator().checkOrigin(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN); + return; + } + // Sub-protocols + List<String> subProtocols = getTokensFromHeader(req, + Constants.WS_PROTOCOL_HEADER_NAME); + String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( + sec.getSubprotocols(), subProtocols); + + // Extensions + // Should normally only be one header but handle the case of multiple + // headers + List<Extension> extensionsRequested = new ArrayList<>(); + Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); + while (extHeaders.hasMoreElements()) { + Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); + } + + // Negotiation phase 1. By default this simply filters out the + // extensions that the server does not support but applications could + // use a custom configurator to do more than this. + List<Extension> installedExtensions = null; + if (sec.getExtensions().size() == 0) { + installedExtensions = Constants.INSTALLED_EXTENSIONS; + } else { + installedExtensions = new ArrayList<>(); + installedExtensions.addAll(sec.getExtensions()); + installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); + } + List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( + installedExtensions, extensionsRequested); + + // Negotiation phase 2. Create the Transformations that will be applied + // to this connection. Note than an extension may be dropped at this + // point if the client has requested a configuration that the server is + // unable to support. + List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1); + + List<Extension> negotiatedExtensionsPhase2; + if (transformations.isEmpty()) { + negotiatedExtensionsPhase2 = Collections.emptyList(); + } else { + negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); + for (Transformation t : transformations) { + negotiatedExtensionsPhase2.add(t.getExtensionResponse()); + } + } + + WsHttpUpgradeHandler wsHandler = + req.upgrade(WsHttpUpgradeHandler.class); + + WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); + WsHandshakeResponse wsResponse = new WsHandshakeResponse(); + WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = + new WsPerSessionServerEndpointConfig(sec); + sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, + wsRequest, wsResponse); + //wsRequest.finished(); + + // Add any additional headers + for (Entry<String,List<String>> entry : + wsResponse.getHeaders().entrySet()) { + for (String headerValue: entry.getValue()) { + resp.addHeader(entry.getKey(), headerValue); + } + } + + Endpoint ep; + try { + Class<?> clazz = sec.getEndpointClass(); + if (Endpoint.class.isAssignableFrom(clazz)) { + ep = (Endpoint) sec.getConfigurator().getEndpointInstance( + clazz); + } else { + ep = new PojoEndpointServer(); + // Need to make path params available to POJO + perSessionServerEndpointConfig.getUserProperties().put( + nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); + } + } catch (InstantiationException e) { + throw new ServletException(e); + } + + wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, + negotiatedExtensionsPhase2, subProtocol, null, pathParams, + req.isSecure()); + + wsHandler.init(null); + } + + + private static List<Transformation> createTransformations( + List<Extension> negotiatedExtensions) { + + TransformationFactory factory = TransformationFactory.getInstance(); + + LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences = + new LinkedHashMap<>(); + + // Result will likely be smaller than this + List<Transformation> result = new ArrayList<>(negotiatedExtensions.size()); + + for (Extension extension : negotiatedExtensions) { + List<List<Extension.Parameter>> preferences = + extensionPreferences.get(extension.getName()); + + if (preferences == null) { + preferences = new ArrayList<>(); + extensionPreferences.put(extension.getName(), preferences); + } + + preferences.add(extension.getParameters()); + } + + for (Map.Entry<String,List<List<Extension.Parameter>>> entry : + extensionPreferences.entrySet()) { + Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); + if (transformation != null) { + result.add(transformation); + } + } + return result; + } + + + private static void append(StringBuilder sb, Extension extension) { + if (extension == null || extension.getName() == null || extension.getName().length() == 0) { + return; + } + + sb.append(extension.getName()); + + for (Extension.Parameter p : extension.getParameters()) { + sb.append(';'); + sb.append(p.getName()); + if (p.getValue() != null) { + sb.append('='); + sb.append(p.getValue()); + } + } + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static boolean headerContainsToken(HttpServletRequest req, + String headerName, String target) { + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + if (target.equalsIgnoreCase(token.trim())) { + return true; + } + } + } + return false; + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static List<String> getTokensFromHeader(HttpServletRequest req, + String headerName) { + List<String> result = new ArrayList<>(); + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + result.add(token.trim()); + } + } + return result; + } + + + private static String getWebSocketAccept(String key) { + byte[] digest = ConcurrentMessageDigest.digestSHA1( + key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); + return Base64.encodeBase64String(digest); + } +} diff --git a/src/java/nginx/unit/websocket/server/UriTemplate.java b/src/java/nginx/unit/websocket/server/UriTemplate.java new file mode 100644 index 00000000..7877fac9 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UriTemplate.java @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.DeploymentException; + +import org.apache.tomcat.util.res.StringManager; + +/** + * Extracts path parameters from URIs used to create web socket connections + * using the URI template defined for the associated Endpoint. + */ +public class UriTemplate { + + private static final StringManager sm = StringManager.getManager(UriTemplate.class); + + private final String normalized; + private final List<Segment> segments = new ArrayList<>(); + private final boolean hasParameters; + + + public UriTemplate(String path) throws DeploymentException { + + if (path == null || path.length() ==0 || !path.startsWith("/")) { + throw new DeploymentException( + sm.getString("uriTemplate.invalidPath", path)); + } + + StringBuilder normalized = new StringBuilder(path.length()); + Set<String> paramNames = new HashSet<>(); + + // Include empty segments. + String[] segments = path.split("/", -1); + int paramCount = 0; + int segmentCount = 0; + + for (int i = 0; i < segments.length; i++) { + String segment = segments[i]; + if (segment.length() == 0) { + if (i == 0 || (i == segments.length - 1 && paramCount == 0)) { + // Ignore the first empty segment as the path must always + // start with '/' + // Ending with a '/' is also OK for instances used for + // matches but not for parameterised templates. + continue; + } else { + // As per EG discussion, all other empty segments are + // invalid + throw new IllegalArgumentException(sm.getString( + "uriTemplate.emptySegment", path)); + } + } + normalized.append('/'); + int index = -1; + if (segment.startsWith("{") && segment.endsWith("}")) { + index = segmentCount; + segment = segment.substring(1, segment.length() - 1); + normalized.append('{'); + normalized.append(paramCount++); + normalized.append('}'); + if (!paramNames.add(segment)) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.duplicateParameter", segment)); + } + } else { + if (segment.contains("{") || segment.contains("}")) { + throw new IllegalArgumentException(sm.getString( + "uriTemplate.invalidSegment", segment, path)); + } + normalized.append(segment); + } + this.segments.add(new Segment(index, segment)); + segmentCount++; + } + + this.normalized = normalized.toString(); + this.hasParameters = paramCount > 0; + } + + + public Map<String,String> match(UriTemplate candidate) { + + Map<String,String> result = new HashMap<>(); + + // Should not happen but for safety + if (candidate.getSegmentCount() != getSegmentCount()) { + return null; + } + + Iterator<Segment> candidateSegments = + candidate.getSegments().iterator(); + Iterator<Segment> targetSegments = segments.iterator(); + + while (candidateSegments.hasNext()) { + Segment candidateSegment = candidateSegments.next(); + Segment targetSegment = targetSegments.next(); + + if (targetSegment.getParameterIndex() == -1) { + // Not a parameter - values must match + if (!targetSegment.getValue().equals( + candidateSegment.getValue())) { + // Not a match. Stop here + return null; + } + } else { + // Parameter + result.put(targetSegment.getValue(), + candidateSegment.getValue()); + } + } + + return result; + } + + + public boolean hasParameters() { + return hasParameters; + } + + + public int getSegmentCount() { + return segments.size(); + } + + + public String getNormalizedPath() { + return normalized; + } + + + private List<Segment> getSegments() { + return segments; + } + + + private static class Segment { + private final int parameterIndex; + private final String value; + + public Segment(int parameterIndex, String value) { + this.parameterIndex = parameterIndex; + this.value = value; + } + + + public int getParameterIndex() { + return parameterIndex; + } + + + public String getValue() { + return value; + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsContextListener.java b/src/java/nginx/unit/websocket/server/WsContextListener.java new file mode 100644 index 00000000..07137856 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsContextListener.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.ServletContext; +import javax.servlet.ServletContextEvent; +import javax.servlet.ServletContextListener; + +/** + * In normal usage, this {@link ServletContextListener} does not need to be + * explicitly configured as the {@link WsSci} performs all the necessary + * bootstrap and installs this listener in the {@link ServletContext}. If the + * {@link WsSci} is disabled, this listener must be added manually to every + * {@link ServletContext} that uses WebSocket to bootstrap the + * {@link WsServerContainer} correctly. + */ +public class WsContextListener implements ServletContextListener { + + @Override + public void contextInitialized(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + // Don't trigger WebSocket initialization if a WebSocket Server + // Container is already present + if (sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE) == null) { + WsSci.init(sce.getServletContext(), false); + } + } + + @Override + public void contextDestroyed(ServletContextEvent sce) { + ServletContext sc = sce.getServletContext(); + Object obj = sc.getAttribute(Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + if (obj instanceof WsServerContainer) { + ((WsServerContainer) obj).destroy(); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsFilter.java b/src/java/nginx/unit/websocket/server/WsFilter.java new file mode 100644 index 00000000..abea71fc --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsFilter.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; + +import javax.servlet.FilterChain; +import javax.servlet.GenericFilter; +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Handles the initial HTTP connection for WebSocket connections. + */ +public class WsFilter extends GenericFilter { + + private static final long serialVersionUID = 1L; + + private transient WsServerContainer sc; + + + @Override + public void init() throws ServletException { + sc = (WsServerContainer) getServletContext().getAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE); + } + + + @Override + public void doFilter(ServletRequest request, ServletResponse response, + FilterChain chain) throws IOException, ServletException { + + // This filter only needs to handle WebSocket upgrade requests + if (!sc.areEndpointsRegistered() || + !UpgradeUtil.isWebSocketUpgradeRequest(request, response)) { + chain.doFilter(request, response); + return; + } + + // HTTP request with an upgrade header for WebSocket present + HttpServletRequest req = (HttpServletRequest) request; + HttpServletResponse resp = (HttpServletResponse) response; + + // Check to see if this WebSocket implementation has a matching mapping + String path; + String pathInfo = req.getPathInfo(); + if (pathInfo == null) { + path = req.getServletPath(); + } else { + path = req.getServletPath() + pathInfo; + } + WsMappingResult mappingResult = sc.findMapping(path); + + if (mappingResult == null) { + // No endpoint registered for the requested path. Let the + // application handle it (it might redirect or forward for example) + chain.doFilter(request, response); + return; + } + + UpgradeUtil.doUpgrade(sc, req, resp, mappingResult.getConfig(), + mappingResult.getPathParams()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java new file mode 100644 index 00000000..fa774302 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHandshakeRequest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.net.URI; +import java.net.URISyntaxException; +import java.security.Principal; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.http.HttpServletRequest; +import javax.websocket.server.HandshakeRequest; + +import org.apache.tomcat.util.collections.CaseInsensitiveKeyMap; +import org.apache.tomcat.util.res.StringManager; + +/** + * Represents the request that this session was opened under. + */ +public class WsHandshakeRequest implements HandshakeRequest { + + private static final StringManager sm = StringManager.getManager(WsHandshakeRequest.class); + + private final URI requestUri; + private final Map<String,List<String>> parameterMap; + private final String queryString; + private final Principal userPrincipal; + private final Map<String,List<String>> headers; + private final Object httpSession; + + private volatile HttpServletRequest request; + + + public WsHandshakeRequest(HttpServletRequest request, Map<String,String> pathParams) { + + this.request = request; + + queryString = request.getQueryString(); + userPrincipal = request.getUserPrincipal(); + httpSession = request.getSession(false); + requestUri = buildRequestUri(request); + + // ParameterMap + Map<String,String[]> originalParameters = request.getParameterMap(); + Map<String,List<String>> newParameters = + new HashMap<>(originalParameters.size()); + for (Entry<String,String[]> entry : originalParameters.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Arrays.asList(entry.getValue()))); + } + for (Entry<String,String> entry : pathParams.entrySet()) { + newParameters.put(entry.getKey(), + Collections.unmodifiableList( + Collections.singletonList(entry.getValue()))); + } + parameterMap = Collections.unmodifiableMap(newParameters); + + // Headers + Map<String,List<String>> newHeaders = new CaseInsensitiveKeyMap<>(); + + Enumeration<String> headerNames = request.getHeaderNames(); + while (headerNames.hasMoreElements()) { + String headerName = headerNames.nextElement(); + + newHeaders.put(headerName, Collections.unmodifiableList( + Collections.list(request.getHeaders(headerName)))); + } + + headers = Collections.unmodifiableMap(newHeaders); + } + + @Override + public URI getRequestURI() { + return requestUri; + } + + @Override + public Map<String,List<String>> getParameterMap() { + return parameterMap; + } + + @Override + public String getQueryString() { + return queryString; + } + + @Override + public Principal getUserPrincipal() { + return userPrincipal; + } + + @Override + public Map<String,List<String>> getHeaders() { + return headers; + } + + @Override + public boolean isUserInRole(String role) { + if (request == null) { + throw new IllegalStateException(); + } + + return request.isUserInRole(role); + } + + @Override + public Object getHttpSession() { + return httpSession; + } + + /** + * Called when the HandshakeRequest is no longer required. Since an instance + * of this class retains a reference to the current HttpServletRequest that + * reference needs to be cleared as the HttpServletRequest may be reused. + * + * There is no reason for instances of this class to be accessed once the + * handshake has been completed. + */ + void finished() { + request = null; + } + + + /* + * See RequestUtil.getRequestURL() + */ + private static URI buildRequestUri(HttpServletRequest req) { + + StringBuffer uri = new StringBuffer(); + String scheme = req.getScheme(); + int port = req.getServerPort(); + if (port < 0) { + // Work around java.net.URL bug + port = 80; + } + + if ("http".equals(scheme)) { + uri.append("ws"); + } else if ("https".equals(scheme)) { + uri.append("wss"); + } else { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.unknownScheme", scheme)); + } + + uri.append("://"); + uri.append(req.getServerName()); + + if ((scheme.equals("http") && (port != 80)) + || (scheme.equals("https") && (port != 443))) { + uri.append(':'); + uri.append(port); + } + + uri.append(req.getRequestURI()); + + if (req.getQueryString() != null) { + uri.append("?"); + uri.append(req.getQueryString()); + } + + try { + return new URI(uri.toString()); + } catch (URISyntaxException e) { + // Should never happen + throw new IllegalArgumentException( + sm.getString("wsHandshakeRequest.invalidUri", uri.toString()), e); + } + } + + public Object getAttribute(String name) + { + return request != null ? request.getAttribute(name) : null; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java new file mode 100644 index 00000000..cc39ab73 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsHttpUpgradeHandler.java @@ -0,0 +1,172 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import javax.servlet.http.HttpSession; +import javax.servlet.http.HttpUpgradeHandler; +import javax.servlet.http.WebConnection; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; + +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsIOException; +import nginx.unit.websocket.WsSession; + +import nginx.unit.Request; + +/** + * Servlet 3.1 HTTP upgrade handler for WebSocket connections. + */ +public class WsHttpUpgradeHandler implements HttpUpgradeHandler { + + private final Log log = LogFactory.getLog(WsHttpUpgradeHandler.class); // must not be static + private static final StringManager sm = StringManager.getManager(WsHttpUpgradeHandler.class); + + private final ClassLoader applicationClassLoader; + + private Endpoint ep; + private EndpointConfig endpointConfig; + private WsServerContainer webSocketContainer; + private WsHandshakeRequest handshakeRequest; + private List<Extension> negotiatedExtensions; + private String subProtocol; + private Transformation transformation; + private Map<String,String> pathParameters; + private boolean secure; + private WebConnection connection; + private WsRemoteEndpointImplServer wsRemoteEndpointServer; + private WsSession wsSession; + + + public WsHttpUpgradeHandler() { + applicationClassLoader = Thread.currentThread().getContextClassLoader(); + } + + public void preInit(Endpoint ep, EndpointConfig endpointConfig, + WsServerContainer wsc, WsHandshakeRequest handshakeRequest, + List<Extension> negotiatedExtensionsPhase2, String subProtocol, + Transformation transformation, Map<String,String> pathParameters, + boolean secure) { + this.ep = ep; + this.endpointConfig = endpointConfig; + this.webSocketContainer = wsc; + this.handshakeRequest = handshakeRequest; + this.negotiatedExtensions = negotiatedExtensionsPhase2; + this.subProtocol = subProtocol; + this.transformation = transformation; + this.pathParameters = pathParameters; + this.secure = secure; + } + + + @Override + public void init(WebConnection connection) { + if (ep == null) { + throw new IllegalStateException( + sm.getString("wsHttpUpgradeHandler.noPreInit")); + } + + String httpSessionId = null; + Object session = handshakeRequest.getHttpSession(); + if (session != null ) { + httpSessionId = ((HttpSession) session).getId(); + } + + nginx.unit.Context.trace("UpgradeHandler.init(" + connection + ")"); + +/* + // Need to call onOpen using the web application's class loader + // Create the frame using the application's class loader so it can pick + // up application specific config from the ServerContainerImpl + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); +*/ + try { + Request r = (Request) handshakeRequest.getAttribute(Request.BARE); + + wsRemoteEndpointServer = new WsRemoteEndpointImplServer(webSocketContainer); + wsSession = new WsSession(ep, wsRemoteEndpointServer, + webSocketContainer, handshakeRequest.getRequestURI(), + handshakeRequest.getParameterMap(), + handshakeRequest.getQueryString(), + handshakeRequest.getUserPrincipal(), httpSessionId, + negotiatedExtensions, subProtocol, pathParameters, secure, + endpointConfig, r); + + ep.onOpen(wsSession, endpointConfig); + webSocketContainer.registerSession(ep, wsSession); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); +/* + } finally { + t.setContextClassLoader(cl); +*/ + } + } + + + + @Override + public void destroy() { + if (connection != null) { + try { + connection.close(); + } catch (Exception e) { + log.error(sm.getString("wsHttpUpgradeHandler.destroyFailed"), e); + } + } + } + + + private void onError(Throwable throwable) { + // Need to call onError using the web application's class loader + Thread t = Thread.currentThread(); + ClassLoader cl = t.getContextClassLoader(); + t.setContextClassLoader(applicationClassLoader); + try { + ep.onError(wsSession, throwable); + } finally { + t.setContextClassLoader(cl); + } + } + + + private void close(CloseReason cr) { + /* + * Any call to this method is a result of a problem reading from the + * client. At this point that state of the connection is unknown. + * Attempt to send a close frame to the client and then close the socket + * immediately. There is no point in waiting for a close frame from the + * client because there is no guarantee that we can recover from + * whatever messed up state the client put the connection into. + */ + wsSession.onClose(cr); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsMappingResult.java b/src/java/nginx/unit/websocket/server/WsMappingResult.java new file mode 100644 index 00000000..a7a4c022 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsMappingResult.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Map; + +import javax.websocket.server.ServerEndpointConfig; + +class WsMappingResult { + + private final ServerEndpointConfig config; + private final Map<String,String> pathParams; + + + WsMappingResult(ServerEndpointConfig config, + Map<String,String> pathParams) { + this.config = config; + this.pathParams = pathParams; + } + + + ServerEndpointConfig getConfig() { + return config; + } + + + Map<String,String> getPathParams() { + return pathParams; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java new file mode 100644 index 00000000..2be050cb --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsPerSessionServerEndpointConfig.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +import javax.websocket.Decoder; +import javax.websocket.Encoder; +import javax.websocket.Extension; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Wraps the provided {@link ServerEndpointConfig} and provides a per session + * view - the difference being that the map returned by {@link + * #getUserProperties()} is unique to this instance rather than shared with the + * wrapped {@link ServerEndpointConfig}. + */ +class WsPerSessionServerEndpointConfig implements ServerEndpointConfig { + + private final ServerEndpointConfig perEndpointConfig; + private final Map<String,Object> perSessionUserProperties = + new ConcurrentHashMap<>(); + + WsPerSessionServerEndpointConfig(ServerEndpointConfig perEndpointConfig) { + this.perEndpointConfig = perEndpointConfig; + perSessionUserProperties.putAll(perEndpointConfig.getUserProperties()); + } + + @Override + public List<Class<? extends Encoder>> getEncoders() { + return perEndpointConfig.getEncoders(); + } + + @Override + public List<Class<? extends Decoder>> getDecoders() { + return perEndpointConfig.getDecoders(); + } + + @Override + public Map<String,Object> getUserProperties() { + return perSessionUserProperties; + } + + @Override + public Class<?> getEndpointClass() { + return perEndpointConfig.getEndpointClass(); + } + + @Override + public String getPath() { + return perEndpointConfig.getPath(); + } + + @Override + public List<String> getSubprotocols() { + return perEndpointConfig.getSubprotocols(); + } + + @Override + public List<Extension> getExtensions() { + return perEndpointConfig.getExtensions(); + } + + @Override + public Configurator getConfigurator() { + return perEndpointConfig.getConfigurator(); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java new file mode 100644 index 00000000..6d10a3be --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsRemoteEndpointImplServer.java @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.EOFException; +import java.io.IOException; +import java.net.SocketTimeoutException; +import java.nio.ByteBuffer; +import java.nio.channels.CompletionHandler; +import java.nio.channels.InterruptedByTimeoutException; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import javax.websocket.SendHandler; +import javax.websocket.SendResult; + +import org.apache.juli.logging.Log; +import org.apache.juli.logging.LogFactory; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.WsRemoteEndpointImplBase; + +/** + * This is the server side {@link javax.websocket.RemoteEndpoint} implementation + * - i.e. what the server uses to send data to the client. + */ +public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase { + + private static final StringManager sm = + StringManager.getManager(WsRemoteEndpointImplServer.class); + private final Log log = LogFactory.getLog(WsRemoteEndpointImplServer.class); // must not be static + + private volatile SendHandler handler = null; + private volatile ByteBuffer[] buffers = null; + + private volatile long timeoutExpiry = -1; + private volatile boolean close; + + public WsRemoteEndpointImplServer( + WsServerContainer serverContainer) { + } + + + @Override + protected final boolean isMasked() { + return false; + } + + @Override + protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, + ByteBuffer... buffers) { + } + + @Override + protected void doClose() { + if (handler != null) { + // close() can be triggered by a wide range of scenarios. It is far + // simpler just to always use a dispatch than it is to try and track + // whether or not this method was called by the same thread that + // triggered the write + clearHandler(new EOFException(), true); + } + } + + + protected long getTimeoutExpiry() { + return timeoutExpiry; + } + + + /* + * Currently this is only called from the background thread so we could just + * call clearHandler() with useDispatch == false but the method parameter + * was added in case other callers started to use this method to make sure + * that those callers think through what the correct value of useDispatch is + * for them. + */ + protected void onTimeout(boolean useDispatch) { + if (handler != null) { + clearHandler(new SocketTimeoutException(), useDispatch); + } + close(); + } + + + @Override + protected void setTransformation(Transformation transformation) { + // Overridden purely so it is visible to other classes in this package + super.setTransformation(transformation); + } + + + /** + * + * @param t The throwable associated with any error that + * occurred + * @param useDispatch Should {@link SendHandler#onResult(SendResult)} be + * called from a new thread, keeping in mind the + * requirements of + * {@link javax.websocket.RemoteEndpoint.Async} + */ + private void clearHandler(Throwable t, boolean useDispatch) { + // Setting the result marks this (partial) message as + // complete which means the next one may be sent which + // could update the value of the handler. Therefore, keep a + // local copy before signalling the end of the (partial) + // message. + SendHandler sh = handler; + handler = null; + buffers = null; + if (sh != null) { + if (useDispatch) { + OnResultRunnable r = new OnResultRunnable(sh, t); + } else { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } + } + + + private static class OnResultRunnable implements Runnable { + + private final SendHandler sh; + private final Throwable t; + + private OnResultRunnable(SendHandler sh, Throwable t) { + this.sh = sh; + this.t = t; + } + + @Override + public void run() { + if (t == null) { + sh.onResult(new SendResult()); + } else { + sh.onResult(new SendResult(t)); + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSci.java b/src/java/nginx/unit/websocket/server/WsSci.java new file mode 100644 index 00000000..cdecce27 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSci.java @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.lang.reflect.Modifier; +import java.util.HashSet; +import java.util.Set; + +import javax.servlet.ServletContainerInitializer; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.annotation.HandlesTypes; +import javax.websocket.ContainerProvider; +import javax.websocket.DeploymentException; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerApplicationConfig; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; + +/** + * Registers an interest in any class that is annotated with + * {@link ServerEndpoint} so that Endpoint can be published via the WebSocket + * server. + */ +@HandlesTypes({ServerEndpoint.class, ServerApplicationConfig.class, + Endpoint.class}) +public class WsSci implements ServletContainerInitializer { + + @Override + public void onStartup(Set<Class<?>> clazzes, ServletContext ctx) + throws ServletException { + + WsServerContainer sc = init(ctx, true); + + if (clazzes == null || clazzes.size() == 0) { + return; + } + + // Group the discovered classes by type + Set<ServerApplicationConfig> serverApplicationConfigs = new HashSet<>(); + Set<Class<? extends Endpoint>> scannedEndpointClazzes = new HashSet<>(); + Set<Class<?>> scannedPojoEndpoints = new HashSet<>(); + + try { + // wsPackage is "javax.websocket." + String wsPackage = ContainerProvider.class.getName(); + wsPackage = wsPackage.substring(0, wsPackage.lastIndexOf('.') + 1); + for (Class<?> clazz : clazzes) { + int modifiers = clazz.getModifiers(); + if (!Modifier.isPublic(modifiers) || + Modifier.isAbstract(modifiers)) { + // Non-public or abstract - skip it. + continue; + } + // Protect against scanning the WebSocket API JARs + if (clazz.getName().startsWith(wsPackage)) { + continue; + } + if (ServerApplicationConfig.class.isAssignableFrom(clazz)) { + serverApplicationConfigs.add( + (ServerApplicationConfig) clazz.getConstructor().newInstance()); + } + if (Endpoint.class.isAssignableFrom(clazz)) { + @SuppressWarnings("unchecked") + Class<? extends Endpoint> endpoint = + (Class<? extends Endpoint>) clazz; + scannedEndpointClazzes.add(endpoint); + } + if (clazz.isAnnotationPresent(ServerEndpoint.class)) { + scannedPojoEndpoints.add(clazz); + } + } + } catch (ReflectiveOperationException e) { + throw new ServletException(e); + } + + // Filter the results + Set<ServerEndpointConfig> filteredEndpointConfigs = new HashSet<>(); + Set<Class<?>> filteredPojoEndpoints = new HashSet<>(); + + if (serverApplicationConfigs.isEmpty()) { + filteredPojoEndpoints.addAll(scannedPojoEndpoints); + } else { + for (ServerApplicationConfig config : serverApplicationConfigs) { + Set<ServerEndpointConfig> configFilteredEndpoints = + config.getEndpointConfigs(scannedEndpointClazzes); + if (configFilteredEndpoints != null) { + filteredEndpointConfigs.addAll(configFilteredEndpoints); + } + Set<Class<?>> configFilteredPojos = + config.getAnnotatedEndpointClasses( + scannedPojoEndpoints); + if (configFilteredPojos != null) { + filteredPojoEndpoints.addAll(configFilteredPojos); + } + } + } + + try { + // Deploy endpoints + for (ServerEndpointConfig config : filteredEndpointConfigs) { + sc.addEndpoint(config); + } + // Deploy POJOs + for (Class<?> clazz : filteredPojoEndpoints) { + sc.addEndpoint(clazz); + } + } catch (DeploymentException e) { + throw new ServletException(e); + } + } + + + static WsServerContainer init(ServletContext servletContext, + boolean initBySciMechanism) { + + WsServerContainer sc = new WsServerContainer(servletContext); + + servletContext.setAttribute( + Constants.SERVER_CONTAINER_SERVLET_CONTEXT_ATTRIBUTE, sc); + + servletContext.addListener(new WsSessionListener(sc)); + // Can't register the ContextListener again if the ContextListener is + // calling this method + if (initBySciMechanism) { + servletContext.addListener(new WsContextListener()); + } + + return sc; + } +} diff --git a/src/java/nginx/unit/websocket/server/WsServerContainer.java b/src/java/nginx/unit/websocket/server/WsServerContainer.java new file mode 100644 index 00000000..069fc54f --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsServerContainer.java @@ -0,0 +1,470 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.Map; +import java.util.Set; +import java.util.SortedSet; +import java.util.TreeSet; +import java.util.concurrent.ConcurrentHashMap; + +import javax.servlet.DispatcherType; +import javax.servlet.FilterRegistration; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.CloseReason; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.Endpoint; +import javax.websocket.server.ServerContainer; +import javax.websocket.server.ServerEndpoint; +import javax.websocket.server.ServerEndpointConfig; +import javax.websocket.server.ServerEndpointConfig.Configurator; + +import org.apache.tomcat.InstanceManager; +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.WsSession; +import nginx.unit.websocket.WsWebSocketContainer; +import nginx.unit.websocket.pojo.PojoMethodMapping; + +/** + * Provides a per class loader (i.e. per web application) instance of a + * ServerContainer. Web application wide defaults may be configured by setting + * the following servlet context initialisation parameters to the desired + * values. + * <ul> + * <li>{@link Constants#BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> + * <li>{@link Constants#TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM}</li> + * </ul> + */ +public class WsServerContainer extends WsWebSocketContainer + implements ServerContainer { + + private static final StringManager sm = StringManager.getManager(WsServerContainer.class); + + private static final CloseReason AUTHENTICATED_HTTP_SESSION_CLOSED = + new CloseReason(CloseCodes.VIOLATED_POLICY, + "This connection was established under an authenticated " + + "HTTP session that has ended."); + + private final ServletContext servletContext; + private final Map<String,ServerEndpointConfig> configExactMatchMap = + new ConcurrentHashMap<>(); + private final Map<Integer,SortedSet<TemplatePathMatch>> configTemplateMatchMap = + new ConcurrentHashMap<>(); + private volatile boolean enforceNoAddAfterHandshake = + nginx.unit.websocket.Constants.STRICT_SPEC_COMPLIANCE; + private volatile boolean addAllowed = true; + private final Map<String,Set<WsSession>> authenticatedSessions = new ConcurrentHashMap<>(); + private volatile boolean endpointsRegistered = false; + + WsServerContainer(ServletContext servletContext) { + + this.servletContext = servletContext; + setInstanceManager((InstanceManager) servletContext.getAttribute(InstanceManager.class.getName())); + + // Configure servlet context wide defaults + String value = servletContext.getInitParameter( + Constants.BINARY_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxBinaryMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.TEXT_BUFFER_SIZE_SERVLET_CONTEXT_INIT_PARAM); + if (value != null) { + setDefaultMaxTextMessageBufferSize(Integer.parseInt(value)); + } + + value = servletContext.getInitParameter( + Constants.ENFORCE_NO_ADD_AFTER_HANDSHAKE_CONTEXT_INIT_PARAM); + if (value != null) { + setEnforceNoAddAfterHandshake(Boolean.parseBoolean(value)); + } + + FilterRegistration.Dynamic fr = servletContext.addFilter( + "Tomcat WebSocket (JSR356) Filter", new WsFilter()); + fr.setAsyncSupported(true); + + EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.REQUEST, + DispatcherType.FORWARD); + + fr.addMappingForUrlPatterns(types, true, "/*"); + } + + + /** + * Published the provided endpoint implementation at the specified path with + * the specified configuration. {@link #WsServerContainer(ServletContext)} + * must be called before calling this method. + * + * @param sec The configuration to use when creating endpoint instances + * @throws DeploymentException if the endpoint cannot be published as + * requested + */ + @Override + public void addEndpoint(ServerEndpointConfig sec) + throws DeploymentException { + + if (enforceNoAddAfterHandshake && !addAllowed) { + throw new DeploymentException( + sm.getString("serverContainer.addNotAllowed")); + } + + if (servletContext == null) { + throw new DeploymentException( + sm.getString("serverContainer.servletContextMissing")); + } + String path = sec.getPath(); + + // Add method mapping to user properties + PojoMethodMapping methodMapping = new PojoMethodMapping(sec.getEndpointClass(), + sec.getDecoders(), path); + if (methodMapping.getOnClose() != null || methodMapping.getOnOpen() != null + || methodMapping.getOnError() != null || methodMapping.hasMessageHandlers()) { + sec.getUserProperties().put(nginx.unit.websocket.pojo.Constants.POJO_METHOD_MAPPING_KEY, + methodMapping); + } + + UriTemplate uriTemplate = new UriTemplate(path); + if (uriTemplate.hasParameters()) { + Integer key = Integer.valueOf(uriTemplate.getSegmentCount()); + SortedSet<TemplatePathMatch> templateMatches = + configTemplateMatchMap.get(key); + if (templateMatches == null) { + // Ensure that if concurrent threads execute this block they + // both end up using the same TreeSet instance + templateMatches = new TreeSet<>( + TemplatePathMatchComparator.getInstance()); + configTemplateMatchMap.putIfAbsent(key, templateMatches); + templateMatches = configTemplateMatchMap.get(key); + } + if (!templateMatches.add(new TemplatePathMatch(sec, uriTemplate))) { + // Duplicate uriTemplate; + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + sec.getEndpointClass(), + sec.getEndpointClass())); + } + } else { + // Exact match + ServerEndpointConfig old = configExactMatchMap.put(path, sec); + if (old != null) { + // Duplicate path mappings + throw new DeploymentException( + sm.getString("serverContainer.duplicatePaths", path, + old.getEndpointClass(), + sec.getEndpointClass())); + } + } + + endpointsRegistered = true; + } + + + /** + * Provides the equivalent of {@link #addEndpoint(ServerEndpointConfig)} + * for publishing plain old java objects (POJOs) that have been annotated as + * WebSocket endpoints. + * + * @param pojo The annotated POJO + */ + @Override + public void addEndpoint(Class<?> pojo) throws DeploymentException { + + ServerEndpoint annotation = pojo.getAnnotation(ServerEndpoint.class); + if (annotation == null) { + throw new DeploymentException( + sm.getString("serverContainer.missingAnnotation", + pojo.getName())); + } + String path = annotation.value(); + + // Validate encoders + validateEncoders(annotation.encoders()); + + // ServerEndpointConfig + ServerEndpointConfig sec; + Class<? extends Configurator> configuratorClazz = + annotation.configurator(); + Configurator configurator = null; + if (!configuratorClazz.equals(Configurator.class)) { + try { + configurator = annotation.configurator().getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.configuratorFail", + annotation.configurator().getName(), + pojo.getClass().getName()), e); + } + } + if (configurator == null) { + configurator = new nginx.unit.websocket.server.DefaultServerEndpointConfigurator(); + } + sec = ServerEndpointConfig.Builder.create(pojo, path). + decoders(Arrays.asList(annotation.decoders())). + encoders(Arrays.asList(annotation.encoders())). + subprotocols(Arrays.asList(annotation.subprotocols())). + configurator(configurator). + build(); + + addEndpoint(sec); + } + + + boolean areEndpointsRegistered() { + return endpointsRegistered; + } + + + /** + * Until the WebSocket specification provides such a mechanism, this Tomcat + * proprietary method is provided to enable applications to programmatically + * determine whether or not to upgrade an individual request to WebSocket. + * <p> + * Note: This method is not used by Tomcat but is used directly by + * third-party code and must not be removed. + * + * @param request The request object to be upgraded + * @param response The response object to be populated with the result of + * the upgrade + * @param sec The server endpoint to use to process the upgrade request + * @param pathParams The path parameters associated with the upgrade request + * + * @throws ServletException If a configuration error prevents the upgrade + * from taking place + * @throws IOException If an I/O error occurs during the upgrade process + */ + public void doUpgrade(HttpServletRequest request, + HttpServletResponse response, ServerEndpointConfig sec, + Map<String,String> pathParams) + throws ServletException, IOException { + UpgradeUtil.doUpgrade(this, request, response, sec, pathParams); + } + + + public WsMappingResult findMapping(String path) { + + // Prevent registering additional endpoints once the first attempt has + // been made to use one + if (addAllowed) { + addAllowed = false; + } + + // Check an exact match. Simple case as there are no templates. + ServerEndpointConfig sec = configExactMatchMap.get(path); + if (sec != null) { + return new WsMappingResult(sec, Collections.<String, String>emptyMap()); + } + + // No exact match. Need to look for template matches. + UriTemplate pathUriTemplate = null; + try { + pathUriTemplate = new UriTemplate(path); + } catch (DeploymentException e) { + // Path is not valid so can't be matched to a WebSocketEndpoint + return null; + } + + // Number of segments has to match + Integer key = Integer.valueOf(pathUriTemplate.getSegmentCount()); + SortedSet<TemplatePathMatch> templateMatches = + configTemplateMatchMap.get(key); + + if (templateMatches == null) { + // No templates with an equal number of segments so there will be + // no matches + return null; + } + + // List is in alphabetical order of normalised templates. + // Correct match is the first one that matches. + Map<String,String> pathParams = null; + for (TemplatePathMatch templateMatch : templateMatches) { + pathParams = templateMatch.getUriTemplate().match(pathUriTemplate); + if (pathParams != null) { + sec = templateMatch.getConfig(); + break; + } + } + + if (sec == null) { + // No match + return null; + } + + return new WsMappingResult(sec, pathParams); + } + + + + public boolean isEnforceNoAddAfterHandshake() { + return enforceNoAddAfterHandshake; + } + + + public void setEnforceNoAddAfterHandshake( + boolean enforceNoAddAfterHandshake) { + this.enforceNoAddAfterHandshake = enforceNoAddAfterHandshake; + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void registerSession(Endpoint endpoint, WsSession wsSession) { + super.registerSession(endpoint, wsSession); + if (wsSession.isOpen() && + wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + registerAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + } + + + /** + * {@inheritDoc} + * + * Overridden to make it visible to other classes in this package. + */ + @Override + protected void unregisterSession(Endpoint endpoint, WsSession wsSession) { + if (wsSession.getUserPrincipal() != null && + wsSession.getHttpSessionId() != null) { + unregisterAuthenticatedSession(wsSession, + wsSession.getHttpSessionId()); + } + super.unregisterSession(endpoint, wsSession); + } + + + private void registerAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); + if (wsSessions == null) { + wsSessions = Collections.newSetFromMap( + new ConcurrentHashMap<WsSession,Boolean>()); + authenticatedSessions.putIfAbsent(httpSessionId, wsSessions); + wsSessions = authenticatedSessions.get(httpSessionId); + } + wsSessions.add(wsSession); + } + + + private void unregisterAuthenticatedSession(WsSession wsSession, + String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.get(httpSessionId); + // wsSessions will be null if the HTTP session has ended + if (wsSessions != null) { + wsSessions.remove(wsSession); + } + } + + + public void closeAuthenticatedSession(String httpSessionId) { + Set<WsSession> wsSessions = authenticatedSessions.remove(httpSessionId); + + if (wsSessions != null && !wsSessions.isEmpty()) { + for (WsSession wsSession : wsSessions) { + try { + wsSession.close(AUTHENTICATED_HTTP_SESSION_CLOSED); + } catch (IOException e) { + // Any IOExceptions during close will have been caught and the + // onError method called. + } + } + } + } + + + private static void validateEncoders(Class<? extends Encoder>[] encoders) + throws DeploymentException { + + for (Class<? extends Encoder> encoder : encoders) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Encoder instance; + try { + encoder.getConstructor().newInstance(); + } catch(ReflectiveOperationException e) { + throw new DeploymentException(sm.getString( + "serverContainer.encoderFail", encoder.getName()), e); + } + } + } + + + private static class TemplatePathMatch { + private final ServerEndpointConfig config; + private final UriTemplate uriTemplate; + + public TemplatePathMatch(ServerEndpointConfig config, + UriTemplate uriTemplate) { + this.config = config; + this.uriTemplate = uriTemplate; + } + + + public ServerEndpointConfig getConfig() { + return config; + } + + + public UriTemplate getUriTemplate() { + return uriTemplate; + } + } + + + /** + * This Comparator implementation is thread-safe so only create a single + * instance. + */ + private static class TemplatePathMatchComparator + implements Comparator<TemplatePathMatch> { + + private static final TemplatePathMatchComparator INSTANCE = + new TemplatePathMatchComparator(); + + public static TemplatePathMatchComparator getInstance() { + return INSTANCE; + } + + private TemplatePathMatchComparator() { + // Hide default constructor + } + + @Override + public int compare(TemplatePathMatch tpm1, TemplatePathMatch tpm2) { + return tpm1.getUriTemplate().getNormalizedPath().compareTo( + tpm2.getUriTemplate().getNormalizedPath()); + } + } +} diff --git a/src/java/nginx/unit/websocket/server/WsSessionListener.java b/src/java/nginx/unit/websocket/server/WsSessionListener.java new file mode 100644 index 00000000..fc2bc9c5 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsSessionListener.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import javax.servlet.http.HttpSessionEvent; +import javax.servlet.http.HttpSessionListener; + +public class WsSessionListener implements HttpSessionListener{ + + private final WsServerContainer wsServerContainer; + + + public WsSessionListener(WsServerContainer wsServerContainer) { + this.wsServerContainer = wsServerContainer; + } + + + @Override + public void sessionDestroyed(HttpSessionEvent se) { + wsServerContainer.closeAuthenticatedSession(se.getSession().getId()); + } +} diff --git a/src/java/nginx/unit/websocket/server/WsWriteTimeout.java b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java new file mode 100644 index 00000000..2dfc4ab2 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/WsWriteTimeout.java @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.util.Comparator; +import java.util.Set; +import java.util.concurrent.ConcurrentSkipListSet; +import java.util.concurrent.atomic.AtomicInteger; + +import nginx.unit.websocket.BackgroundProcess; +import nginx.unit.websocket.BackgroundProcessManager; + +/** + * Provides timeouts for asynchronous web socket writes. On the server side we + * only have access to {@link javax.servlet.ServletOutputStream} and + * {@link javax.servlet.ServletInputStream} so there is no way to set a timeout + * for writes to the client. + */ +public class WsWriteTimeout implements BackgroundProcess { + + private final Set<WsRemoteEndpointImplServer> endpoints = + new ConcurrentSkipListSet<>(new EndpointComparator()); + private final AtomicInteger count = new AtomicInteger(0); + private int backgroundProcessCount = 0; + private volatile int processPeriod = 1; + + @Override + public void backgroundProcess() { + // This method gets called once a second. + backgroundProcessCount ++; + + if (backgroundProcessCount >= processPeriod) { + backgroundProcessCount = 0; + + long now = System.currentTimeMillis(); + for (WsRemoteEndpointImplServer endpoint : endpoints) { + if (endpoint.getTimeoutExpiry() < now) { + // Background thread, not the thread that triggered the + // write so no need to use a dispatch + endpoint.onTimeout(false); + } else { + // Endpoints are ordered by timeout expiry so if this point + // is reached there is no need to check the remaining + // endpoints + break; + } + } + } + } + + + @Override + public void setProcessPeriod(int period) { + this.processPeriod = period; + } + + + /** + * {@inheritDoc} + * + * The default value is 1 which means asynchronous write timeouts are + * processed every 1 second. + */ + @Override + public int getProcessPeriod() { + return processPeriod; + } + + + public void register(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.add(endpoint); + if (result) { + int newCount = count.incrementAndGet(); + if (newCount == 1) { + BackgroundProcessManager.getInstance().register(this); + } + } + } + + + public void unregister(WsRemoteEndpointImplServer endpoint) { + boolean result = endpoints.remove(endpoint); + if (result) { + int newCount = count.decrementAndGet(); + if (newCount == 0) { + BackgroundProcessManager.getInstance().unregister(this); + } + } + } + + + /** + * Note: this comparator imposes orderings that are inconsistent with equals + */ + private static class EndpointComparator implements + Comparator<WsRemoteEndpointImplServer> { + + @Override + public int compare(WsRemoteEndpointImplServer o1, + WsRemoteEndpointImplServer o2) { + + long t1 = o1.getTimeoutExpiry(); + long t2 = o2.getTimeoutExpiry(); + + if (t1 < t2) { + return -1; + } else if (t1 == t2) { + return 0; + } else { + return 1; + } + } + } +} diff --git a/src/java/nginx/unit/websocket/server/package-info.java b/src/java/nginx/unit/websocket/server/package-info.java new file mode 100644 index 00000000..87bc85a3 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/package-info.java @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +/** + * Server-side specific implementation classes. These are in a separate package + * to make packaging a pure client JAR simpler. + */ +package nginx.unit.websocket.server; diff --git a/src/java/nxt_jni_Request.c b/src/java/nxt_jni_Request.c index 733290dd..2e9dce67 100644 --- a/src/java/nxt_jni_Request.c +++ b/src/java/nxt_jni_Request.c @@ -58,16 +58,28 @@ static jint JNICALL nxt_java_Request_getServerPort(JNIEnv *env, jclass cls, jlong req_ptr); static jboolean JNICALL nxt_java_Request_isSecure(JNIEnv *env, jclass cls, jlong req_ptr); +static void JNICALL nxt_java_Request_upgrade(JNIEnv *env, jclass cls, + jlong req_info_ptr); +static jboolean JNICALL nxt_java_Request_isUpgrade(JNIEnv *env, jclass cls, + jlong req_info_ptr); static void JNICALL nxt_java_Request_log(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len); static void JNICALL nxt_java_Request_trace(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len); static jobject JNICALL nxt_java_Request_getResponse(JNIEnv *env, jclass cls, jlong req_info_ptr); +static void JNICALL nxt_java_Request_sendWsFrameBuf(JNIEnv *env, jclass cls, + jlong req_info_ptr, jobject buf, jint pos, jint len, jbyte opCode, jboolean last); +static void JNICALL nxt_java_Request_sendWsFrameArr(JNIEnv *env, jclass cls, + jlong req_info_ptr, jarray arr, jint pos, jint len, jbyte opCode, jboolean last); +static void JNICALL nxt_java_Request_closeWs(JNIEnv *env, jclass cls, + jlong req_info_ptr); static jclass nxt_java_Request_class; static jmethodID nxt_java_Request_ctor; +static jmethodID nxt_java_Request_processWsFrame; +static jmethodID nxt_java_Request_closeWsSession; int @@ -91,6 +103,18 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) return NXT_UNIT_ERROR; } + nxt_java_Request_processWsFrame = (*env)->GetMethodID(env, cls, "processWsFrame", "(Ljava/nio/ByteBuffer;BZ)V"); + if (nxt_java_Request_processWsFrame == NULL) { + (*env)->DeleteGlobalRef(env, cls); + return NXT_UNIT_ERROR; + } + + nxt_java_Request_closeWsSession = (*env)->GetMethodID(env, cls, "closeWsSession", "()V"); + if (nxt_java_Request_closeWsSession == NULL) { + (*env)->DeleteGlobalRef(env, cls); + return NXT_UNIT_ERROR; + } + JNINativeMethod request_methods[] = { { (char *) "getHeader", (char *) "(JLjava/lang/String;I)Ljava/lang/String;", @@ -172,6 +196,14 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) (char *) "(J)Z", nxt_java_Request_isSecure }, + { (char *) "upgrade", + (char *) "(J)V", + nxt_java_Request_upgrade }, + + { (char *) "isUpgrade", + (char *) "(J)Z", + nxt_java_Request_isUpgrade }, + { (char *) "log", (char *) "(JLjava/lang/String;I)V", nxt_java_Request_log }, @@ -184,6 +216,18 @@ nxt_java_initRequest(JNIEnv *env, jobject cl) (char *) "(J)Lnginx/unit/Response;", nxt_java_Request_getResponse }, + { (char *) "sendWsFrame", + (char *) "(JLjava/nio/ByteBuffer;IIBZ)V", + nxt_java_Request_sendWsFrameBuf }, + + { (char *) "sendWsFrame", + (char *) "(J[BIIBZ)V", + nxt_java_Request_sendWsFrameArr }, + + { (char *) "closeWs", + (char *) "(J)V", + nxt_java_Request_closeWs }, + }; res = (*env)->RegisterNatives(env, nxt_java_Request_class, @@ -625,6 +669,32 @@ nxt_java_Request_isSecure(JNIEnv *env, jclass cls, jlong req_ptr) static void JNICALL +nxt_java_Request_upgrade(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + + if (!nxt_unit_response_is_init(req)) { + nxt_unit_response_init(req, 101, 0, 0); + } + + (void) nxt_unit_response_upgrade(req); +} + + +static jboolean JNICALL +nxt_java_Request_isUpgrade(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + + return nxt_unit_request_is_websocket_handshake(req); +} + + +static void JNICALL nxt_java_Request_log(JNIEnv *env, jclass cls, jlong req_info_ptr, jstring msg, jint msg_len) { @@ -677,3 +747,77 @@ nxt_java_Request_getResponse(JNIEnv *env, jclass cls, jlong req_info_ptr) return data->jresp; } + + +static void JNICALL +nxt_java_Request_sendWsFrameBuf(JNIEnv *env, jclass cls, + jlong req_info_ptr, jobject buf, jint pos, jint len, jbyte opCode, jboolean last) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + uint8_t *b = (*env)->GetDirectBufferAddress(env, buf); + + if (b != NULL) { + nxt_unit_websocket_send(req, opCode, last, b + pos, len); + + } else { + nxt_unit_req_debug(req, "sendWsFrameBuf: b == NULL"); + } +} + + +static void JNICALL +nxt_java_Request_sendWsFrameArr(JNIEnv *env, jclass cls, + jlong req_info_ptr, jarray arr, jint pos, jint len, jbyte opCode, jboolean last) +{ + nxt_unit_request_info_t *req; + + req = nxt_jlong2ptr(req_info_ptr); + uint8_t *b = (*env)->GetPrimitiveArrayCritical(env, arr, NULL); + + if (b != NULL) { + if (!nxt_unit_response_is_sent(req)) { + nxt_unit_response_send(req); + } + + nxt_unit_websocket_send(req, opCode, last, b + pos, len); + + (*env)->ReleasePrimitiveArrayCritical(env, arr, b, 0); + + } else { + nxt_unit_req_debug(req, "sendWsFrameArr: b == NULL"); + } +} + + +static void JNICALL +nxt_java_Request_closeWs(JNIEnv *env, jclass cls, jlong req_info_ptr) +{ + nxt_unit_request_info_t *req; + nxt_java_request_data_t *data; + + req = nxt_jlong2ptr(req_info_ptr); + + data = req->data; + + (*env)->DeleteGlobalRef(env, data->jresp); + (*env)->DeleteGlobalRef(env, data->jreq); + + nxt_unit_request_done(req, NXT_UNIT_OK); +} + + +void +nxt_java_Request_websocket(JNIEnv *env, jobject jreq, jobject jbuf, + uint8_t opcode, uint8_t fin) +{ + (*env)->CallVoidMethod(env, jreq, nxt_java_Request_processWsFrame, jbuf, opcode, fin); +} + + +void +nxt_java_Request_close(JNIEnv *env, jobject jreq) +{ + (*env)->CallVoidMethod(env, jreq, nxt_java_Request_closeWsSession); +} diff --git a/src/java/nxt_jni_Request.h b/src/java/nxt_jni_Request.h index 1c9c1428..9187d878 100644 --- a/src/java/nxt_jni_Request.h +++ b/src/java/nxt_jni_Request.h @@ -15,4 +15,9 @@ int nxt_java_initRequest(JNIEnv *env, jobject cl); jobject nxt_java_newRequest(JNIEnv *env, jobject ctx, nxt_unit_request_info_t *req); +void nxt_java_Request_websocket(JNIEnv *env, jobject jreq, jobject jbuf, + uint8_t opcode, uint8_t fin); + +void nxt_java_Request_close(JNIEnv *env, jobject jreq); + #endif /* _NXT_JAVA_REQUEST_H_INCLUDED_ */ diff --git a/src/nxt_application.h b/src/nxt_application.h index 7ff4bb11..2a1fa39e 100644 --- a/src/nxt_application.h +++ b/src/nxt_application.h @@ -88,6 +88,8 @@ struct nxt_common_app_conf_s { char *working_directory; nxt_conf_value_t *environment; + nxt_conf_value_t *isolation; + union { nxt_external_app_conf_t external; nxt_python_app_conf_t python; diff --git a/src/nxt_capability.c b/src/nxt_capability.c new file mode 100644 index 00000000..805faff6 --- /dev/null +++ b/src/nxt_capability.c @@ -0,0 +1,104 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#include <nxt_main.h> + +#if (NXT_HAVE_LINUX_CAPABILITY) + +#include <linux/capability.h> +#include <sys/syscall.h> + +#define nxt_capget(hdrp, datap) \ + syscall(SYS_capget, hdrp, datap) +#define nxt_capset(hdrp, datap) \ + syscall(SYS_capset, hdrp, datap) + +#endif /* NXT_HAVE_LINUX_CAPABILITY */ + + +static nxt_int_t nxt_capability_specific_set(nxt_task_t *task, + nxt_capabilities_t *cap); + + +nxt_int_t +nxt_capability_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + nxt_assert(cap->setid == 0); + + if (geteuid() == 0) { + cap->setid = 1; + return NXT_OK; + } + + return nxt_capability_specific_set(task, cap); +} + + +#if (NXT_HAVE_LINUX_CAPABILITY) + +static uint32_t +nxt_capability_linux_get_version() +{ + struct __user_cap_header_struct hdr; + + hdr.version = _LINUX_CAPABILITY_VERSION; + hdr.pid = nxt_pid; + + nxt_capget(&hdr, NULL); + return hdr.version; +} + + +static nxt_int_t +nxt_capability_specific_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + struct __user_cap_data_struct *val, data[2]; + struct __user_cap_header_struct hdr; + + /* + * Linux capability v1 fills an u32 struct. + * Linux capability v2 and v3 fills an u64 struct. + * We allocate data[2] for compatibility, we waste 4 bytes on v1. + * + * This is safe as we only need to check CAP_SETUID and CAP_SETGID + * that resides in the first 32-bit chunk. + */ + + val = &data[0]; + + /* + * Ask the kernel the preferred capability version + * instead of using _LINUX_CAPABILITY_VERSION from header. + * This is safer when distributing a pre-compiled Unit binary. + */ + hdr.version = nxt_capability_linux_get_version(); + hdr.pid = nxt_pid; + + if (nxt_slow_path(nxt_capget(&hdr, val) == -1)) { + nxt_alert(task, "failed to get process capabilities: %E", nxt_errno); + return NXT_ERROR; + } + + if ((val->effective & (1 << CAP_SETUID)) == 0) { + return NXT_OK; + } + + if ((val->effective & (1 << CAP_SETGID)) == 0) { + return NXT_OK; + } + + cap->setid = 1; + return NXT_OK; +} + +#else + +static nxt_int_t +nxt_capability_specific_set(nxt_task_t *task, nxt_capabilities_t *cap) +{ + return NXT_OK; +} + +#endif diff --git a/src/nxt_capability.h b/src/nxt_capability.h new file mode 100644 index 00000000..60bbd5f8 --- /dev/null +++ b/src/nxt_capability.h @@ -0,0 +1,17 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#ifndef _NXT_CAPABILITY_INCLUDED_ +#define _NXT_CAPABILITY_INCLUDED_ + +typedef struct { + uint8_t setid; /* 1 bit */ +} nxt_capabilities_t; + + +NXT_EXPORT nxt_int_t nxt_capability_set(nxt_task_t *task, + nxt_capabilities_t *cap); + +#endif /* _NXT_CAPABILITY_INCLUDED_ */ diff --git a/src/nxt_clone.c b/src/nxt_clone.c new file mode 100644 index 00000000..0fddd6c7 --- /dev/null +++ b/src/nxt_clone.c @@ -0,0 +1,263 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#include <nxt_main.h> +#include <sys/types.h> +#include <nxt_conf.h> +#include <nxt_clone.h> + +#if (NXT_HAVE_CLONE) + +pid_t +nxt_clone(nxt_int_t flags) +{ +#if defined(__s390x__) || defined(__s390__) || defined(__CRIS__) + return syscall(__NR_clone, NULL, flags); +#else + return syscall(__NR_clone, flags, NULL); +#endif +} + +#endif + + +#if (NXT_HAVE_CLONE_NEWUSER) + +/* map uid 65534 to unit pid */ +#define NXT_DEFAULT_UNPRIV_MAP "65534 %d 1" + +nxt_int_t nxt_clone_proc_setgroups(nxt_task_t *task, pid_t child_pid, + const char *str); +nxt_int_t nxt_clone_proc_map_set(nxt_task_t *task, const char* mapfile, + pid_t pid, nxt_int_t defval, nxt_conf_value_t *mapobj); +nxt_int_t nxt_clone_proc_map_write(nxt_task_t *task, const char *mapfile, + pid_t pid, u_char *mapinfo); + + +typedef struct { + nxt_int_t container; + nxt_int_t host; + nxt_int_t size; +} nxt_clone_procmap_t; + + +nxt_int_t +nxt_clone_proc_setgroups(nxt_task_t *task, pid_t child_pid, const char *str) +{ + int fd, n; + u_char *p, *end; + u_char path[PATH_MAX]; + + end = path + PATH_MAX; + p = nxt_sprintf(path, end, "/proc/%d/setgroups", child_pid); + *p = '\0'; + + if (nxt_slow_path(p == end)) { + nxt_alert(task, "error write past the buffer: %s", path); + return NXT_ERROR; + } + + fd = open((char *)path, O_RDWR); + + if (fd == -1) { + /* + * If the /proc/pid/setgroups doesn't exists, we are + * safe to set uid/gid maps. But if the error is anything + * other than ENOENT, then we should abort and let user know. + */ + + if (errno != ENOENT) { + nxt_alert(task, "open(%s): %E", path, nxt_errno); + return NXT_ERROR; + } + + return NXT_OK; + } + + n = write(fd, str, strlen(str)); + close(fd); + + if (nxt_slow_path(n == -1)) { + nxt_alert(task, "write(%s): %E", path, nxt_errno); + return NXT_ERROR; + } + + return NXT_OK; +} + + +nxt_int_t +nxt_clone_proc_map_write(nxt_task_t *task, const char *mapfile, pid_t pid, + u_char *mapinfo) +{ + int len, mapfd; + u_char *p, *end; + ssize_t n; + u_char buf[256]; + + end = buf + sizeof(buf); + + p = nxt_sprintf(buf, end, "/proc/%d/%s", pid, mapfile); + if (nxt_slow_path(p == end)) { + nxt_alert(task, "writing past the buffer"); + return NXT_ERROR; + } + + *p = '\0'; + + mapfd = open((char*)buf, O_RDWR); + if (nxt_slow_path(mapfd == -1)) { + nxt_alert(task, "failed to open proc map (%s) %E", buf, nxt_errno); + return NXT_ERROR; + } + + len = nxt_strlen(mapinfo); + + n = write(mapfd, (char *)mapinfo, len); + if (nxt_slow_path(n != len)) { + + if (n == -1 && nxt_errno == EINVAL) { + nxt_alert(task, "failed to write %s: Check kernel maximum " \ + "allowed lines %E", buf, nxt_errno); + + } else { + nxt_alert(task, "failed to write proc map (%s) %E", buf, + nxt_errno); + } + + return NXT_ERROR; + } + + return NXT_OK; +} + + +nxt_int_t +nxt_clone_proc_map_set(nxt_task_t *task, const char* mapfile, pid_t pid, + nxt_int_t defval, nxt_conf_value_t *mapobj) +{ + u_char *p, *end, *mapinfo; + nxt_int_t container, host, size; + nxt_int_t ret, len, count, i; + nxt_conf_value_t *obj, *value; + + static nxt_str_t str_cont = nxt_string("container"); + static nxt_str_t str_host = nxt_string("host"); + static nxt_str_t str_size = nxt_string("size"); + + /* + * uid_map one-entry size: + * alloc space for 3 numbers (32bit) plus 2 spaces and \n. + */ + len = sizeof(u_char) * (10 + 10 + 10 + 2 + 1); + + if (mapobj != NULL) { + count = nxt_conf_array_elements_count(mapobj); + + if (count == 0) { + goto default_map; + } + + len = len * count + 1; + + mapinfo = nxt_malloc(len); + if (nxt_slow_path(mapinfo == NULL)) { + nxt_alert(task, "failed to allocate uid_map buffer"); + return NXT_ERROR; + } + + p = mapinfo; + end = mapinfo + len; + + for (i = 0; i < count; i++) { + obj = nxt_conf_get_array_element(mapobj, i); + + value = nxt_conf_get_object_member(obj, &str_cont, NULL); + container = nxt_conf_get_integer(value); + + value = nxt_conf_get_object_member(obj, &str_host, NULL); + host = nxt_conf_get_integer(value); + + value = nxt_conf_get_object_member(obj, &str_size, NULL); + size = nxt_conf_get_integer(value); + + p = nxt_sprintf(p, end, "%d %d %d", container, host, size); + if (nxt_slow_path(p == end)) { + nxt_alert(task, "write past the uid_map buffer"); + nxt_free(mapinfo); + return NXT_ERROR; + } + + if (i+1 < count) { + *p++ = '\n'; + + } else { + *p = '\0'; + } + } + + } else { + +default_map: + + mapinfo = nxt_malloc(len); + if (nxt_slow_path(mapinfo == NULL)) { + nxt_alert(task, "failed to allocate uid_map buffer"); + return NXT_ERROR; + } + + end = mapinfo + len; + p = nxt_sprintf(mapinfo, end, NXT_DEFAULT_UNPRIV_MAP, defval); + *p = '\0'; + + if (nxt_slow_path(p == end)) { + nxt_alert(task, "write past the %s buffer", mapfile); + nxt_free(mapinfo); + return NXT_ERROR; + } + } + + ret = nxt_clone_proc_map_write(task, mapfile, pid, mapinfo); + + nxt_free(mapinfo); + + return ret; +} + + +nxt_int_t +nxt_clone_proc_map(nxt_task_t *task, pid_t pid, nxt_process_clone_t *clone) +{ + nxt_int_t ret; + nxt_int_t uid, gid; + const char *rule; + nxt_runtime_t *rt; + + rt = task->thread->runtime; + uid = geteuid(); + gid = getegid(); + + rule = rt->capabilities.setid ? "allow" : "deny"; + + ret = nxt_clone_proc_map_set(task, "uid_map", pid, uid, clone->uidmap); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + ret = nxt_clone_proc_setgroups(task, pid, rule); + if (nxt_slow_path(ret != NXT_OK)) { + nxt_alert(task, "failed to write /proc/%d/setgroups", pid); + return NXT_ERROR; + } + + ret = nxt_clone_proc_map_set(task, "gid_map", pid, gid, clone->gidmap); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + return NXT_OK; +} + +#endif diff --git a/src/nxt_clone.h b/src/nxt_clone.h new file mode 100644 index 00000000..50dec0b4 --- /dev/null +++ b/src/nxt_clone.h @@ -0,0 +1,17 @@ +/* + * Copyright (C) Igor Sysoev + * Copyright (C) NGINX, Inc. + */ + +#ifndef _NXT_CLONE_INCLUDED_ +#define _NXT_CLONE_INCLUDED_ + + +pid_t nxt_clone(nxt_int_t flags); + +#if (NXT_HAVE_CLONE_NEWUSER) +nxt_int_t nxt_clone_proc_map(nxt_task_t *task, pid_t pid, + nxt_process_clone_t *clone); +#endif + +#endif /* _NXT_CLONE_INCLUDED_ */ diff --git a/src/nxt_conf.c b/src/nxt_conf.c index 57870838..59eddd77 100644 --- a/src/nxt_conf.c +++ b/src/nxt_conf.c @@ -16,6 +16,8 @@ #define NXT_CONF_MAX_SHORT_STRING 14 #define NXT_CONF_MAX_STRING NXT_INT32_T_MAX +#define NXT_CONF_MAX_TOKEN_LEN 256 + typedef enum { NXT_CONF_VALUE_NULL = 0, @@ -90,6 +92,17 @@ struct nxt_conf_op_s { }; +typedef struct { + u_char *start; + u_char *end; + nxt_bool_t last; + u_char buf[NXT_CONF_MAX_TOKEN_LEN]; +} nxt_conf_path_parse_t; + + +static nxt_int_t nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, + nxt_str_t *token); + static u_char *nxt_conf_json_skip_space(u_char *start, u_char *end); static u_char *nxt_conf_json_parse_value(nxt_mp_t *mp, nxt_conf_value_t *value, u_char *start, u_char *end, nxt_conf_json_error_t *error); @@ -402,22 +415,11 @@ nxt_conf_type(nxt_conf_value_t *value) } -typedef struct { - u_char *start; - u_char *end; - nxt_bool_t last; -} nxt_conf_path_parse_t; - - -static void nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, - nxt_str_t *token); - - nxt_conf_value_t * nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) { nxt_str_t token; - nxt_int_t index; + nxt_int_t ret, index; nxt_conf_path_parse_t parse; parse.start = path->start; @@ -425,7 +427,10 @@ nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) parse.last = 0; do { - nxt_conf_path_next_token(&parse, &token); + ret = nxt_conf_path_next_token(&parse, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NULL; + } if (token.length == 0) { @@ -466,24 +471,38 @@ nxt_conf_get_path(nxt_conf_value_t *value, nxt_str_t *path) } -static void +static nxt_int_t nxt_conf_path_next_token(nxt_conf_path_parse_t *parse, nxt_str_t *token) { - u_char *p, *end; + u_char *p, *start, *end; + size_t length; - end = parse->end; - p = parse->start + 1; + start = parse->start + 1; - token->start = p; + p = start; - while (p < end && *p != '/') { + while (p < parse->end && *p != '/') { p++; } parse->start = p; - parse->last = (p >= end); + parse->last = (p >= parse->end); + + length = p - start; + + if (nxt_slow_path(length > NXT_CONF_MAX_TOKEN_LEN)) { + return NXT_ERROR; + } + + end = nxt_decode_uri(parse->buf, start, length); + if (nxt_slow_path(end == NULL)) { + return NXT_ERROR; + } - token->length = p - token->start; + token->length = end - parse->buf; + token->start = parse->buf; + + return NXT_OK; } @@ -742,7 +761,7 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, nxt_str_t *path, nxt_conf_value_t *value, nxt_bool_t add) { nxt_str_t token; - nxt_int_t index; + nxt_int_t ret, index; nxt_conf_op_t *op, **parent; nxt_conf_value_t *node; nxt_conf_path_parse_t parse; @@ -763,7 +782,10 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, *parent = op; parent = (nxt_conf_op_t **) &op->ctx; - nxt_conf_path_next_token(&parse, &token); + ret = nxt_conf_path_next_token(&parse, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_CONF_OP_ERROR; + } switch (root->type) { @@ -857,7 +879,10 @@ nxt_conf_op_compile(nxt_mp_t *mp, nxt_conf_op_t **ops, nxt_conf_value_t *root, return NXT_CONF_OP_ERROR; } - nxt_conf_set_string(&member->name, &token); + ret = nxt_conf_set_string_dup(&member->name, mp, &token); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_CONF_OP_ERROR; + } member->value = *value; diff --git a/src/nxt_conf.h b/src/nxt_conf.h index 2435b0e2..725a6c95 100644 --- a/src/nxt_conf.h +++ b/src/nxt_conf.h @@ -71,6 +71,7 @@ typedef struct { nxt_conf_value_t *conf; nxt_mp_t *pool; nxt_str_t error; + void *ctx; } nxt_conf_validation_t; diff --git a/src/nxt_conf_validation.c b/src/nxt_conf_validation.c index ca8ec62e..c934b10b 100644 --- a/src/nxt_conf_validation.c +++ b/src/nxt_conf_validation.c @@ -8,6 +8,7 @@ #include <nxt_conf.h> #include <nxt_cert.h> #include <nxt_router.h> +#include <nxt_http.h> typedef enum { @@ -39,15 +40,18 @@ typedef nxt_int_t (*nxt_conf_vldt_member_t)(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); typedef nxt_int_t (*nxt_conf_vldt_element_t)(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); -typedef nxt_int_t (*nxt_conf_vldt_system_t)(nxt_conf_validation_t *vldt, - char *name); - static nxt_int_t nxt_conf_vldt_type(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value, nxt_conf_vldt_type_t type); static nxt_int_t nxt_conf_vldt_error(nxt_conf_validation_t *vldt, const char *fmt, ...); +static nxt_int_t nxt_conf_vldt_mtypes(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value, void *data); +static nxt_int_t nxt_conf_vldt_mtypes_type(nxt_conf_validation_t *vldt, + nxt_str_t *name, nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_mtypes_extension(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); static nxt_int_t nxt_conf_vldt_listener(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value); #if (NXT_TLS) @@ -86,10 +90,6 @@ static nxt_int_t nxt_conf_vldt_object_iterator(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, void *data); static nxt_int_t nxt_conf_vldt_array_iterator(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, void *data); -static nxt_int_t nxt_conf_vldt_system(nxt_conf_validation_t *vldt, - nxt_conf_value_t *value, void *data); -static nxt_int_t nxt_conf_vldt_user(nxt_conf_validation_t *vldt, char *name); -static nxt_int_t nxt_conf_vldt_group(nxt_conf_validation_t *vldt, char *name); static nxt_int_t nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value); static nxt_int_t nxt_conf_vldt_argument(nxt_conf_validation_t *vldt, @@ -101,6 +101,21 @@ static nxt_int_t nxt_conf_vldt_java_classpath(nxt_conf_validation_t *vldt, static nxt_int_t nxt_conf_vldt_java_option(nxt_conf_validation_t *vldt, nxt_conf_value_t *value); +static nxt_int_t +nxt_conf_vldt_isolation(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data); +static nxt_int_t +nxt_conf_vldt_clone_namespaces(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value, void *data); + +#if (NXT_HAVE_CLONE_NEWUSER) +static nxt_int_t nxt_conf_vldt_clone_procmap(nxt_conf_validation_t *vldt, + const char* mapfile, nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_clone_uidmap(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); +static nxt_int_t nxt_conf_vldt_clone_gidmap(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value); +#endif static nxt_conf_vldt_object_t nxt_conf_vldt_websocket_members[] = { { nxt_string("read_timeout"), @@ -122,6 +137,16 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_websocket_members[] = { }; +static nxt_conf_vldt_object_t nxt_conf_vldt_static_members[] = { + { nxt_string("mime_types"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_mtypes, + NULL }, + + NXT_CONF_VLDT_END +}; + + static nxt_conf_vldt_object_t nxt_conf_vldt_http_members[] = { { nxt_string("header_read_timeout"), NXT_CONF_VLDT_INTEGER, @@ -153,6 +178,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_http_members[] = { &nxt_conf_vldt_object, (void *) &nxt_conf_vldt_websocket_members }, + { nxt_string("static"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_object, + (void *) &nxt_conf_vldt_static_members }, + NXT_CONF_VLDT_END }; @@ -281,6 +311,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_action_members[] = { &nxt_conf_vldt_pass, NULL }, + { nxt_string("share"), + NXT_CONF_VLDT_STRING, + NULL, + NULL }, + NXT_CONF_VLDT_END }; @@ -340,6 +375,100 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_app_processes_members[] = { }; +static nxt_conf_vldt_object_t nxt_conf_vldt_app_namespaces_members[] = { + +#if (NXT_HAVE_CLONE_NEWUSER) + { nxt_string("credential"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWPID) + { nxt_string("pid"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWNET) + { nxt_string("network"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWNS) + { nxt_string("mount"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWUTS) + { nxt_string("uname"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + +#if (NXT_HAVE_CLONE_NEWCGROUP) + { nxt_string("cgroup"), + NXT_CONF_VLDT_BOOLEAN, + NULL, + NULL }, +#endif + + NXT_CONF_VLDT_END +}; + + +#if (NXT_HAVE_CLONE_NEWUSER) + +static nxt_conf_vldt_object_t nxt_conf_vldt_app_procmap_members[] = { + { nxt_string("container"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, + + { nxt_string("host"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, + + { nxt_string("size"), + NXT_CONF_VLDT_INTEGER, + NULL, + NULL }, +}; + +#endif + + +static nxt_conf_vldt_object_t nxt_conf_vldt_app_isolation_members[] = { + { nxt_string("namespaces"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_clone_namespaces, + (void *) &nxt_conf_vldt_app_namespaces_members }, + +#if (NXT_HAVE_CLONE_NEWUSER) + + { nxt_string("uidmap"), + NXT_CONF_VLDT_ARRAY, + &nxt_conf_vldt_array_iterator, + (void *) &nxt_conf_vldt_clone_uidmap }, + + { nxt_string("gidmap"), + NXT_CONF_VLDT_ARRAY, + &nxt_conf_vldt_array_iterator, + (void *) &nxt_conf_vldt_clone_gidmap }, + +#endif + + NXT_CONF_VLDT_END +}; + + static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { { nxt_string("type"), NXT_CONF_VLDT_STRING, @@ -358,13 +487,13 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { { nxt_string("user"), NXT_CONF_VLDT_STRING, - nxt_conf_vldt_system, - (void *) &nxt_conf_vldt_user }, + NULL, + NULL }, { nxt_string("group"), NXT_CONF_VLDT_STRING, - nxt_conf_vldt_system, - (void *) &nxt_conf_vldt_group }, + NULL, + NULL }, { nxt_string("working_directory"), NXT_CONF_VLDT_STRING, @@ -376,6 +505,11 @@ static nxt_conf_vldt_object_t nxt_conf_vldt_common_members[] = { &nxt_conf_vldt_object_iterator, (void *) &nxt_conf_vldt_environment }, + { nxt_string("isolation"), + NXT_CONF_VLDT_OBJECT, + &nxt_conf_vldt_isolation, + (void *) &nxt_conf_vldt_app_isolation_members }, + NXT_CONF_VLDT_END }; @@ -628,6 +762,108 @@ nxt_conf_vldt_error(nxt_conf_validation_t *vldt, const char *fmt, ...) } +typedef struct { + nxt_mp_t *pool; + nxt_str_t *type; + nxt_lvlhsh_t hash; +} nxt_conf_vldt_mtypes_ctx_t; + + +static nxt_int_t +nxt_conf_vldt_mtypes(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) +{ + nxt_int_t ret; + nxt_conf_vldt_mtypes_ctx_t ctx; + + ctx.pool = nxt_mp_create(1024, 128, 256, 32); + if (nxt_slow_path(ctx.pool == NULL)) { + return NXT_ERROR; + } + + nxt_lvlhsh_init(&ctx.hash); + + vldt->ctx = &ctx; + + ret = nxt_conf_vldt_object_iterator(vldt, value, + &nxt_conf_vldt_mtypes_type); + + vldt->ctx = NULL; + + nxt_mp_destroy(ctx.pool); + + return ret; +} + + +static nxt_int_t +nxt_conf_vldt_mtypes_type(nxt_conf_validation_t *vldt, nxt_str_t *name, + nxt_conf_value_t *value) +{ + nxt_int_t ret; + nxt_conf_vldt_mtypes_ctx_t *ctx; + + ret = nxt_conf_vldt_type(vldt, name, value, + NXT_CONF_VLDT_STRING|NXT_CONF_VLDT_ARRAY); + if (ret != NXT_OK) { + return ret; + } + + ctx = vldt->ctx; + + ctx->type = nxt_mp_get(ctx->pool, sizeof(nxt_str_t)); + if (nxt_slow_path(ctx->type == NULL)) { + return NXT_ERROR; + } + + *ctx->type = *name; + + if (nxt_conf_type(value) == NXT_CONF_ARRAY) { + return nxt_conf_vldt_array_iterator(vldt, value, + &nxt_conf_vldt_mtypes_extension); + } + + /* NXT_CONF_STRING */ + + return nxt_conf_vldt_mtypes_extension(vldt, value); +} + + +static nxt_int_t +nxt_conf_vldt_mtypes_extension(nxt_conf_validation_t *vldt, + nxt_conf_value_t *value) +{ + nxt_str_t ext, *dup_type; + nxt_conf_vldt_mtypes_ctx_t *ctx; + + ctx = vldt->ctx; + + if (nxt_conf_type(value) != NXT_CONF_STRING) { + return nxt_conf_vldt_error(vldt, "The \"%V\" MIME type array must " + "contain only strings.", ctx->type); + } + + nxt_conf_get_string(value, &ext); + + if (ext.length == 0) { + return nxt_conf_vldt_error(vldt, "An empty file extension for " + "the \"%V\" MIME type.", ctx->type); + } + + dup_type = nxt_http_static_mtypes_hash_find(&ctx->hash, &ext); + + if (dup_type != NULL) { + return nxt_conf_vldt_error(vldt, "The \"%V\" file extension has been " + "declared for \"%V\" and \"%V\" " + "MIME types at the same time.", + &ext, dup_type, ctx->type); + } + + return nxt_http_static_mtypes_hash_add(ctx->pool, &ctx->hash, + &ext, ctx->type); +} + + static nxt_int_t nxt_conf_vldt_listener(nxt_conf_validation_t *vldt, nxt_str_t *name, nxt_conf_value_t *value) @@ -1252,106 +1488,168 @@ nxt_conf_vldt_array_iterator(nxt_conf_validation_t *vldt, static nxt_int_t -nxt_conf_vldt_system(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, - void *data) +nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, + nxt_conf_value_t *value) { - size_t length; - nxt_str_t name; - nxt_conf_vldt_system_t validator; - char string[32]; + nxt_str_t str; - /* The cast is required by Sun C. */ - validator = (nxt_conf_vldt_system_t) data; + if (name->length == 0) { + return nxt_conf_vldt_error(vldt, + "The environment name must not be empty."); + } - nxt_conf_get_string(value, &name); + if (nxt_memchr(name->start, '\0', name->length) != NULL) { + return nxt_conf_vldt_error(vldt, "The environment name must not " + "contain null character."); + } - length = name.length + 1; - length = nxt_min(length, sizeof(string)); + if (nxt_memchr(name->start, '=', name->length) != NULL) { + return nxt_conf_vldt_error(vldt, "The environment name must not " + "contain '=' character."); + } - nxt_cpystrn((u_char *) string, name.start, length); + if (nxt_conf_type(value) != NXT_CONF_STRING) { + return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must be " + "a string.", name); + } + + nxt_conf_get_string(value, &str); - return validator(vldt, string); + if (nxt_memchr(str.start, '\0', str.length) != NULL) { + return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must " + "not contain null character.", name); + } + + return NXT_OK; } static nxt_int_t -nxt_conf_vldt_user(nxt_conf_validation_t *vldt, char *user) +nxt_conf_vldt_clone_namespaces(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) { - struct passwd *pwd; + return nxt_conf_vldt_object(vldt, value, data); +} - nxt_errno = 0; - pwd = getpwnam(user); +static nxt_int_t +nxt_conf_vldt_isolation(nxt_conf_validation_t *vldt, nxt_conf_value_t *value, + void *data) +{ + return nxt_conf_vldt_object(vldt, value, data); +} - if (pwd != NULL) { - return NXT_OK; - } - if (nxt_errno == 0) { - return nxt_conf_vldt_error(vldt, "User \"%s\" is not found.", user); - } +#if (NXT_HAVE_CLONE_NEWUSER) + +typedef struct { + nxt_int_t container; + nxt_int_t host; + nxt_int_t size; +} nxt_conf_vldt_clone_procmap_conf_t; - return NXT_ERROR; -} + +static nxt_conf_map_t nxt_conf_vldt_clone_procmap_conf_map[] = { + { + nxt_string("container"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, container), + }, + + { + nxt_string("host"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, host), + }, + + { + nxt_string("size"), + NXT_CONF_MAP_INT32, + offsetof(nxt_conf_vldt_clone_procmap_conf_t, size), + }, + +}; static nxt_int_t -nxt_conf_vldt_group(nxt_conf_validation_t *vldt, char *group) +nxt_conf_vldt_clone_procmap(nxt_conf_validation_t *vldt, const char *mapfile, + nxt_conf_value_t *value) { - struct group *grp; + nxt_int_t ret; + nxt_conf_vldt_clone_procmap_conf_t procmap; - nxt_errno = 0; + procmap.container = -1; + procmap.host = -1; + procmap.size = -1; - grp = getgrnam(group); + ret = nxt_conf_map_object(vldt->pool, value, + nxt_conf_vldt_clone_procmap_conf_map, + nxt_nitems(nxt_conf_vldt_clone_procmap_conf_map), + &procmap); + if (ret != NXT_OK) { + return ret; + } - if (grp != NULL) { - return NXT_OK; + if (procmap.container == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"container\" field set.", mapfile); } - if (nxt_errno == 0) { - return nxt_conf_vldt_error(vldt, "Group \"%s\" is not found.", group); + if (procmap.host == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"host\" field set.", mapfile); } - return NXT_ERROR; + if (procmap.size == -1) { + return nxt_conf_vldt_error(vldt, "The %s requires the " + "\"size\" field set.", mapfile); + } + + return NXT_OK; } static nxt_int_t -nxt_conf_vldt_environment(nxt_conf_validation_t *vldt, nxt_str_t *name, - nxt_conf_value_t *value) +nxt_conf_vldt_clone_uidmap(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) { - nxt_str_t str; + nxt_int_t ret; - if (name->length == 0) { - return nxt_conf_vldt_error(vldt, - "The environment name must not be empty."); + if (nxt_conf_type(value) != NXT_CONF_OBJECT) { + return nxt_conf_vldt_error(vldt, "The \"uidmap\" array " + "must contain only object values."); } - if (nxt_memchr(name->start, '\0', name->length) != NULL) { - return nxt_conf_vldt_error(vldt, "The environment name must not " - "contain null character."); + ret = nxt_conf_vldt_object(vldt, value, + (void *) nxt_conf_vldt_app_procmap_members); + if (nxt_slow_path(ret != NXT_OK)) { + return ret; } - if (nxt_memchr(name->start, '=', name->length) != NULL) { - return nxt_conf_vldt_error(vldt, "The environment name must not " - "contain '=' character."); - } + return nxt_conf_vldt_clone_procmap(vldt, "uid_map", value); +} - if (nxt_conf_type(value) != NXT_CONF_STRING) { - return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must be " - "a string.", name); - } - nxt_conf_get_string(value, &str); +static nxt_int_t +nxt_conf_vldt_clone_gidmap(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) +{ + nxt_int_t ret; - if (nxt_memchr(str.start, '\0', str.length) != NULL) { - return nxt_conf_vldt_error(vldt, "The \"%V\" environment value must " - "not contain null character.", name); + if (nxt_conf_type(value) != NXT_CONF_OBJECT) { + return nxt_conf_vldt_error(vldt, "The \"gidmap\" array " + "must contain only object values."); } - return NXT_OK; + ret = nxt_conf_vldt_object(vldt, value, + (void *) nxt_conf_vldt_app_procmap_members); + if (nxt_slow_path(ret != NXT_OK)) { + return ret; + } + + return nxt_conf_vldt_clone_procmap(vldt, "gid_map", value); } +#endif + static nxt_int_t nxt_conf_vldt_argument(nxt_conf_validation_t *vldt, nxt_conf_value_t *value) diff --git a/src/nxt_controller.c b/src/nxt_controller.c index 49afbe46..86ba1246 100644 --- a/src/nxt_controller.c +++ b/src/nxt_controller.c @@ -478,6 +478,8 @@ nxt_controller_conn_init(nxt_task_t *task, void *obj, void *data) return; } + r->parser.encoded_slashes = 1; + b = nxt_buf_mem_alloc(c->mem_pool, 1024, 0); if (nxt_slow_path(b == NULL)) { nxt_controller_conn_free(task, c, NULL); diff --git a/src/nxt_h1proto.c b/src/nxt_h1proto.c index a60bdb36..39be4315 100644 --- a/src/nxt_h1proto.c +++ b/src/nxt_h1proto.c @@ -45,7 +45,7 @@ static void nxt_h1p_conn_request_body_read(nxt_task_t *task, void *obj, void *data); static void nxt_h1p_request_local_addr(nxt_task_t *task, nxt_http_request_t *r); static void nxt_h1p_request_header_send(nxt_task_t *task, - nxt_http_request_t *r); + nxt_http_request_t *r, nxt_work_handler_t body_handler); static void nxt_h1p_request_send(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *out); static nxt_buf_t *nxt_h1p_chunk_create(nxt_task_t *task, nxt_http_request_t *r, @@ -830,8 +830,7 @@ nxt_h1p_request_body_read(nxt_task_t *task, nxt_http_request_t *r) ready: - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - r->state->ready_handler, task, r, NULL); + r->state->ready_handler(task, r, NULL); return; @@ -884,8 +883,7 @@ nxt_h1p_conn_request_body_read(nxt_task_t *task, void *obj, void *data) c->read = NULL; r = h1p->request; - nxt_work_queue_add(&engine->fast_work_queue, r->state->ready_handler, - task, r, NULL); + r->state->ready_handler(task, r, NULL); } } @@ -996,7 +994,8 @@ static const nxt_str_t nxt_http_server_error[] = { #define UNKNOWN_STATUS_LENGTH nxt_length("HTTP/1.1 65536\r\n") static void -nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) +nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler) { u_char *p; size_t size; @@ -1079,6 +1078,7 @@ nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) if (http11) { if (n != NXT_HTTP_NOT_MODIFIED && n != NXT_HTTP_NO_CONTENT + && body_handler != NULL && !h1p->websocket) { h1p->chunked = 1; @@ -1165,6 +1165,19 @@ nxt_h1p_request_header_send(nxt_task_t *task, nxt_http_request_t *r) c->write = header; c->write_state = &nxt_h1p_request_send_state; + if (body_handler != NULL) { + /* + * The body handler will run before c->io->write() handler, + * because the latter was inqueued by nxt_conn_write() + * in engine->write_work_queue. + */ + nxt_work_queue_add(&task->thread->engine->fast_work_queue, + body_handler, task, r, NULL); + + } else { + header->next = nxt_http_buf_last(r); + } + nxt_conn_write(task->thread->engine, c); if (h1p->websocket) { diff --git a/src/nxt_h1proto_websocket.c b/src/nxt_h1proto_websocket.c index dd9b6848..13754be0 100644 --- a/src/nxt_h1proto_websocket.c +++ b/src/nxt_h1proto_websocket.c @@ -417,16 +417,13 @@ nxt_h1p_conn_ws_frame_process(nxt_task_t *task, nxt_conn_t *c, uint8_t *p, *mask; uint16_t code; nxt_http_request_t *r; - nxt_event_engine_t *engine; - engine = task->thread->engine; r = h1p->request; c->read = NULL; if (nxt_slow_path(wsh->opcode == NXT_WEBSOCKET_OP_PING)) { - nxt_work_queue_add(&engine->fast_work_queue, nxt_h1p_conn_ws_pong, - task, r, NULL); + nxt_h1p_conn_ws_pong(task, r, NULL); return; } @@ -451,8 +448,7 @@ nxt_h1p_conn_ws_frame_process(nxt_task_t *task, nxt_conn_t *c, h1p->websocket_closed = 1; } - nxt_work_queue_add(&engine->fast_work_queue, r->state->ready_handler, - task, r, NULL); + r->state->ready_handler(task, r, NULL); } diff --git a/src/nxt_http.h b/src/nxt_http.h index ac1eedcf..560b7310 100644 --- a/src/nxt_http.h +++ b/src/nxt_http.h @@ -24,7 +24,9 @@ typedef enum { NXT_HTTP_NOT_MODIFIED = 304, NXT_HTTP_BAD_REQUEST = 400, + NXT_HTTP_FORBIDDEN = 403, NXT_HTTP_NOT_FOUND = 404, + NXT_HTTP_METHOD_NOT_ALLOWED = 405, NXT_HTTP_REQUEST_TIMEOUT = 408, NXT_HTTP_LENGTH_REQUIRED = 411, NXT_HTTP_PAYLOAD_TOO_LARGE = 413, @@ -175,7 +177,8 @@ struct nxt_http_pass_s { typedef struct { void (*body_read)(nxt_task_t *task, nxt_http_request_t *r); void (*local_addr)(nxt_task_t *task, nxt_http_request_t *r); - void (*header_send)(nxt_task_t *task, nxt_http_request_t *r); + void (*header_send)(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler); void (*send)(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *out); nxt_off_t (*body_bytes_sent)(nxt_task_t *task, nxt_http_proto_t proto); void (*discard)(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *last); @@ -186,6 +189,25 @@ typedef struct { } nxt_http_proto_table_t; +#define NXT_HTTP_DATE_LEN nxt_length("Wed, 31 Dec 1986 16:40:00 GMT") + +nxt_inline u_char * +nxt_http_date(u_char *buf, struct tm *tm) +{ + static const char *week[] = { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", + "Sat" }; + + static const char *month[] = { "Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; + + return nxt_sprintf(buf, buf + NXT_HTTP_DATE_LEN, + "%s, %02d %s %4d %02d:%02d:%02d GMT", + week[tm->tm_wday], tm->tm_mday, + month[tm->tm_mon], tm->tm_year + 1900, + tm->tm_hour, tm->tm_min, tm->tm_sec); +} + + nxt_int_t nxt_http_init(nxt_task_t *task, nxt_runtime_t *rt); nxt_int_t nxt_h1p_init(nxt_task_t *task, nxt_runtime_t *rt); nxt_int_t nxt_http_response_hash_init(nxt_task_t *task, nxt_runtime_t *rt); @@ -195,7 +217,8 @@ nxt_http_request_t *nxt_http_request_create(nxt_task_t *task); void nxt_http_request_error(nxt_task_t *task, nxt_http_request_t *r, nxt_http_status_t status); void nxt_http_request_read_body(nxt_task_t *task, nxt_http_request_t *r); -void nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r); +void nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler); void nxt_http_request_ws_frame_start(nxt_task_t *task, nxt_http_request_t *r, nxt_buf_t *ws_frame); void nxt_http_request_send(nxt_task_t *task, nxt_http_request_t *r, @@ -223,6 +246,14 @@ nxt_http_pass_t *nxt_http_pass_application(nxt_task_t *task, void nxt_http_routes_cleanup(nxt_task_t *task, nxt_http_routes_t *routes); void nxt_http_pass_cleanup(nxt_task_t *task, nxt_http_pass_t *pass); +nxt_http_pass_t *nxt_http_static_handler(nxt_task_t *task, + nxt_http_request_t *r, nxt_http_pass_t *pass); +nxt_int_t nxt_http_static_mtypes_init(nxt_mp_t *mp, nxt_lvlhsh_t *hash); +nxt_int_t nxt_http_static_mtypes_hash_add(nxt_mp_t *mp, nxt_lvlhsh_t *hash, + nxt_str_t *extension, nxt_str_t *type); +nxt_str_t *nxt_http_static_mtypes_hash_find(nxt_lvlhsh_t *hash, + nxt_str_t *extension); + nxt_http_pass_t *nxt_http_request_application(nxt_task_t *task, nxt_http_request_t *r, nxt_http_pass_t *pass); diff --git a/src/nxt_http_error.c b/src/nxt_http_error.c index c7c7e81a..1dcd8783 100644 --- a/src/nxt_http_error.c +++ b/src/nxt_http_error.c @@ -51,12 +51,10 @@ nxt_http_request_error(nxt_task_t *task, nxt_http_request_t *r, r->resp.content_length = NULL; r->resp.content_length_n = nxt_length(error); - nxt_http_request_header_send(task, r); - r->state = &nxt_http_request_send_error_body_state; - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - nxt_http_request_send_error_body, task, r, NULL); + nxt_http_request_header_send(task, r, nxt_http_request_send_error_body); + return; fail: diff --git a/src/nxt_http_parse.c b/src/nxt_http_parse.c index 8b4bf47c..5b009d96 100644 --- a/src/nxt_http_parse.c +++ b/src/nxt_http_parse.c @@ -47,7 +47,6 @@ typedef enum { NXT_HTTP_TARGET_DOT, /* . */ NXT_HTTP_TARGET_ARGS_MARK, /* ? */ NXT_HTTP_TARGET_QUOTE_MARK, /* % */ - NXT_HTTP_TARGET_PLUS, /* + */ } nxt_http_target_traps_e; @@ -57,7 +56,7 @@ static const uint8_t nxt_http_target_chars[256] nxt_aligned(64) = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, /* \s ! " # $ % & ' ( ) * + , - . / */ - 1, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 9, 0, 0, 6, 5, + 1, 0, 0, 2, 0, 8, 0, 0, 0, 0, 0, 0, 0, 0, 6, 5, /* 0 1 2 3 4 5 6 7 8 9 : ; < = > ? */ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 7, @@ -163,8 +162,9 @@ static nxt_int_t nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, u_char *end) { - u_char *p, ch, *after_slash; + u_char *p, ch, *after_slash, *exten, *args; nxt_int_t rc; + nxt_bool_t rest; nxt_http_ver_t ver; nxt_http_target_traps_e trap; @@ -255,6 +255,11 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rp->target_start = p; after_slash = p + 1; + exten = NULL; + args = NULL; + rest = 0; + +continue_target: for ( ;; ) { p++; @@ -269,8 +274,7 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, } after_slash = p + 1; - - rp->exten_start = NULL; + exten = NULL; continue; case NXT_HTTP_TARGET_DOT: @@ -279,11 +283,11 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, goto rest_of_target; } - rp->exten_start = p + 1; + exten = p + 1; continue; case NXT_HTTP_TARGET_ARGS_MARK: - rp->args_start = p + 1; + args = p + 1; goto rest_of_target; case NXT_HTTP_TARGET_SPACE: @@ -294,10 +298,6 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rp->quoted_target = 1; goto rest_of_target; - case NXT_HTTP_TARGET_PLUS: - rp->plus_in_target = 1; - continue; - case NXT_HTTP_TARGET_HASH: rp->complex_target = 1; goto rest_of_target; @@ -316,6 +316,8 @@ nxt_http_parse_request_line(nxt_http_request_parse_t *rp, u_char **pos, rest_of_target: + rest = 1; + for ( ;; ) { p++; @@ -382,7 +384,12 @@ space_after_target: } rp->space_in_target = 1; - goto rest_of_target; + + if (rest) { + goto rest_of_target; + } + + goto continue_target; } /* " HTTP/1.1\r\n" or " HTTP/1.1\n" */ @@ -396,7 +403,12 @@ space_after_target: } rp->space_in_target = 1; - goto rest_of_target; + + if (rest) { + goto rest_of_target; + } + + goto continue_target; } nxt_memcpy(ver.str, &p[1], 8); @@ -437,20 +449,19 @@ space_after_target: rp->path.start = rp->target_start; - if (rp->args_start != NULL) { - rp->path.length = rp->args_start - rp->target_start - 1; + if (args != NULL) { + rp->path.length = args - rp->target_start - 1; - rp->args.start = rp->args_start; - rp->args.length = rp->target_end - rp->args_start; + rp->args.length = rp->target_end - args; + rp->args.start = args; } else { rp->path.length = rp->target_end - rp->target_start; } - if (rp->exten_start) { - rp->exten.length = rp->path.start + rp->path.length - - rp->exten_start; - rp->exten.start = rp->exten_start; + if (exten != NULL) { + rp->exten.length = (rp->path.start + rp->path.length) - exten; + rp->exten.start = exten; } return nxt_http_parse_field_name(rp, pos, end); @@ -835,7 +846,8 @@ static const uint8_t nxt_http_normal[32] nxt_aligned(32) = { static nxt_int_t nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) { - u_char *p, *u, c, ch, high; + u_char *p, *u, c, ch, high, *exten, *args; + enum { sw_normal = 0, sw_slash, @@ -852,7 +864,6 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) p = rp->target_start; u = nxt_mp_alloc(rp->mem_pool, rp->target_end - p + 1); - if (nxt_slow_path(u == NULL)) { return NXT_ERROR; } @@ -861,8 +872,8 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) rp->path.start = u; high = '\0'; - rp->exten_start = NULL; - rp->args_start = NULL; + exten = NULL; + args = NULL; while (p < rp->target_end) { @@ -881,7 +892,7 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) switch (ch) { case '/': - rp->exten_start = NULL; + exten = NULL; state = sw_slash; *u++ = ch; continue; @@ -890,17 +901,14 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; case '.': - rp->exten_start = u + 1; + exten = u + 1; *u++ = ch; continue; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: *u++ = ch; continue; @@ -928,13 +936,10 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -965,13 +970,10 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -1009,13 +1011,10 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_quoted; continue; case '?': - rp->args_start = p; + args = p; goto args; case '#': goto done; - case '+': - rp->plus_in_target = 1; - /* Fall through. */ default: state = sw_normal; *u++ = ch; @@ -1046,12 +1045,25 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) if (ch >= '0' && ch <= '9') { ch = (u_char) ((high << 4) + ch - '0'); - if (ch == '%' || ch == '#') { + if (ch == '%') { state = sw_normal; - *u++ = ch; + *u++ = '%'; + + if (rp->encoded_slashes) { + *u++ = '2'; + *u++ = '5'; + } + continue; + } - } else if (ch == '\0') { + if (ch == '#') { + state = sw_normal; + *u++ = '#'; + continue; + } + + if (ch == '\0') { return NXT_HTTP_PARSE_INVALID; } @@ -1067,9 +1079,14 @@ nxt_http_parse_complex_target(nxt_http_request_parse_t *rp) state = sw_normal; *u++ = ch; continue; + } - } else if (ch == '+') { - rp->plus_in_target = 1; + if (ch == '/' && rp->encoded_slashes) { + state = sw_normal; + *u++ = '%'; + *u++ = '2'; + *u++ = p[-1]; /* 'f' or 'F' */ + continue; } state = saved_state; @@ -1092,18 +1109,18 @@ args: } } - if (rp->args_start != NULL) { - rp->args.length = p - rp->args_start; - rp->args.start = rp->args_start; + if (args != NULL) { + rp->args.length = p - args; + rp->args.start = args; } done: rp->path.length = u - rp->path.start; - if (rp->exten_start) { - rp->exten.length = u - rp->exten_start; - rp->exten.start = rp->exten_start; + if (exten) { + rp->exten.length = u - exten; + rp->exten.start = exten; } return NXT_OK; diff --git a/src/nxt_http_parse.h b/src/nxt_http_parse.h index c5b11bf3..a307ea73 100644 --- a/src/nxt_http_parse.h +++ b/src/nxt_http_parse.h @@ -37,14 +37,10 @@ struct nxt_http_request_parse_s { nxt_int_t (*handler)(nxt_http_request_parse_t *rp, u_char **pos, u_char *end); - size_t offset; - nxt_str_t method; u_char *target_start; u_char *target_end; - u_char *exten_start; - u_char *args_start; nxt_str_t path; nxt_str_t args; @@ -61,13 +57,14 @@ struct nxt_http_request_parse_s { uint32_t field_hash; /* target with "/." */ - unsigned complex_target:1; + uint8_t complex_target; /* 1 bit */ /* target with "%" */ - unsigned quoted_target:1; + uint8_t quoted_target; /* 1 bit */ /* target with " " */ - unsigned space_in_target:1; - /* target with "+" */ - unsigned plus_in_target:1; + uint8_t space_in_target; /* 1 bit */ + + /* Preserve encoded '/' (%2F) and '%' (%25). */ + uint8_t encoded_slashes; /* 1 bit */ }; @@ -113,7 +110,7 @@ nxt_int_t nxt_http_fields_process(nxt_list_t *fields, nxt_lvlhsh_t *hash, void *ctx); -const nxt_lvlhsh_proto_t nxt_http_fields_hash_proto; +extern const nxt_lvlhsh_proto_t nxt_http_fields_hash_proto; nxt_inline nxt_int_t nxt_http_field_process(nxt_http_field_t *field, nxt_lvlhsh_t *hash, void *ctx) diff --git a/src/nxt_http_request.c b/src/nxt_http_request.c index 916004d2..a18a02e7 100644 --- a/src/nxt_http_request.c +++ b/src/nxt_http_request.c @@ -17,8 +17,8 @@ static void nxt_http_request_mem_buf_completion(nxt_task_t *task, void *obj, void *data); static void nxt_http_request_done(nxt_task_t *task, void *obj, void *data); -static u_char *nxt_http_date(u_char *buf, nxt_realtime_t *now, struct tm *tm, - size_t size, const char *format); +static u_char *nxt_http_date_cache_handler(u_char *buf, nxt_realtime_t *now, + struct tm *tm, size_t size, const char *format); static const nxt_http_request_state_t nxt_http_request_init_state; @@ -27,9 +27,9 @@ static const nxt_http_request_state_t nxt_http_request_body_state; nxt_time_string_t nxt_http_date_cache = { (nxt_atomic_uint_t) -1, - nxt_http_date, - "%s, %02d %s %4d %02d:%02d:%02d GMT", - nxt_length("Wed, 31 Dec 1986 16:40:00 GMT"), + nxt_http_date_cache_handler, + NULL, + NXT_HTTP_DATE_LEN, NXT_THREAD_TIME_GMT, NXT_THREAD_TIME_SEC, }; @@ -369,7 +369,8 @@ nxt_http_request_read_body(nxt_task_t *task, nxt_http_request_t *r) void -nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r) +nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r, + nxt_work_handler_t body_handler) { u_char *p, *end; nxt_http_field_t *server, *date, *content_length; @@ -430,7 +431,7 @@ nxt_http_request_header_send(nxt_task_t *task, nxt_http_request_t *r) } if (nxt_fast_path(r->proto.any != NULL)) { - nxt_http_proto[r->protocol].header_send(task, r); + nxt_http_proto[r->protocol].header_send(task, r, body_handler); } return; @@ -577,17 +578,8 @@ nxt_http_request_close_handler(nxt_task_t *task, void *obj, void *data) static u_char * -nxt_http_date(u_char *buf, nxt_realtime_t *now, struct tm *tm, size_t size, - const char *format) +nxt_http_date_cache_handler(u_char *buf, nxt_realtime_t *now, struct tm *tm, + size_t size, const char *format) { - static const char *week[] = { "Sun", "Mon", "Tue", "Wed", "Thu", "Fri", - "Sat" }; - - static const char *month[] = { "Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec" }; - - return nxt_sprintf(buf, buf + size, format, - week[tm->tm_wday], tm->tm_mday, - month[tm->tm_mon], tm->tm_year + 1900, - tm->tm_hour, tm->tm_min, tm->tm_sec); + return nxt_http_date(buf, tm); } diff --git a/src/nxt_http_route.c b/src/nxt_http_route.c index 0b665573..c3c11faa 100644 --- a/src/nxt_http_route.c +++ b/src/nxt_http_route.c @@ -376,15 +376,9 @@ nxt_http_route_match_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_http_route_match_conf_t mtcf; static nxt_str_t pass_path = nxt_string("/action/pass"); + static nxt_str_t share_path = nxt_string("/action/share"); static nxt_str_t match_path = nxt_string("/match"); - pass_conf = nxt_conf_get_path(cv, &pass_path); - if (nxt_slow_path(pass_conf == NULL)) { - return NULL; - } - - nxt_conf_get_string(pass_conf, &pass); - match_conf = nxt_conf_get_path(cv, &match_path); n = (match_conf != NULL) ? nxt_conf_object_members_count(match_conf) : 0; @@ -401,6 +395,19 @@ nxt_http_route_match_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, match->pass.handler = NULL; match->items = n; + pass_conf = nxt_conf_get_path(cv, &pass_path); + + if (pass_conf == NULL) { + pass_conf = nxt_conf_get_path(cv, &share_path); + if (nxt_slow_path(pass_conf == NULL)) { + return NULL; + } + + match->pass.handler = nxt_http_static_handler; + } + + nxt_conf_get_string(pass_conf, &pass); + string = nxt_str_dup(mp, &match->pass.name, &pass); if (nxt_slow_path(string == NULL)) { return NULL; @@ -870,13 +877,18 @@ static void nxt_http_route_resolve(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_http_route_t *route) { + nxt_http_pass_t *pass; nxt_http_route_match_t **match, **end; match = &route->match[0]; end = match + route->items; while (match < end) { - nxt_http_pass_resolve(task, tmcf, &(*match)->pass); + pass = &(*match)->pass; + + if (pass->handler == NULL) { + nxt_http_pass_resolve(task, tmcf, &(*match)->pass); + } match++; } diff --git a/src/nxt_http_static.c b/src/nxt_http_static.c new file mode 100644 index 00000000..44b85389 --- /dev/null +++ b/src/nxt_http_static.c @@ -0,0 +1,599 @@ + +/* + * Copyright (C) NGINX, Inc. + */ + +#include <nxt_router.h> +#include <nxt_http.h> + + +#define NXT_HTTP_STATIC_BUF_COUNT 2 +#define NXT_HTTP_STATIC_BUF_SIZE (128 * 1024) + + +static void nxt_http_static_extract_extension(nxt_str_t *path, + nxt_str_t *extension); +static void nxt_http_static_body_handler(nxt_task_t *task, void *obj, + void *data); +static void nxt_http_static_buf_completion(nxt_task_t *task, void *obj, + void *data); + +static nxt_int_t nxt_http_static_mtypes_hash_test(nxt_lvlhsh_query_t *lhq, + void *data); +static void *nxt_http_static_mtypes_hash_alloc(void *data, size_t size); +static void nxt_http_static_mtypes_hash_free(void *data, void *p); + + +static const nxt_http_request_state_t nxt_http_static_send_state; + + +nxt_http_pass_t * +nxt_http_static_handler(nxt_task_t *task, nxt_http_request_t *r, + nxt_http_pass_t *pass) +{ + size_t alloc, encode; + u_char *p; + struct tm tm; + nxt_buf_t *fb; + nxt_int_t ret; + nxt_str_t index, extension, *mtype; + nxt_uint_t level; + nxt_bool_t need_body; + nxt_file_t *f; + nxt_file_info_t fi; + nxt_http_field_t *field; + nxt_http_status_t status; + nxt_router_conf_t *rtcf; + nxt_work_handler_t body_handler; + + if (nxt_slow_path(!nxt_str_eq(r->method, "GET", 3))) { + + if (!nxt_str_eq(r->method, "HEAD", 4)) { + nxt_http_request_error(task, r, NXT_HTTP_METHOD_NOT_ALLOWED); + return NULL; + } + + need_body = 0; + + } else { + need_body = 1; + } + + f = nxt_mp_zget(r->mem_pool, sizeof(nxt_file_t)); + if (nxt_slow_path(f == NULL)) { + goto fail; + } + + f->fd = NXT_FILE_INVALID; + + if (r->path->start[r->path->length - 1] == '/') { + /* TODO: dynamic index setting. */ + nxt_str_set(&index, "index.html"); + nxt_str_set(&extension, ".html"); + + } else { + nxt_str_null(&index); + nxt_str_null(&extension); + } + + alloc = pass->name.length + r->path->length + index.length + 1; + + f->name = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(f->name == NULL)) { + goto fail; + } + + p = f->name; + p = nxt_cpymem(p, pass->name.start, pass->name.length); + p = nxt_cpymem(p, r->path->start, r->path->length); + p = nxt_cpymem(p, index.start, index.length); + *p = '\0'; + + ret = nxt_file_open(task, f, NXT_FILE_RDONLY, NXT_FILE_OPEN, 0); + + if (nxt_slow_path(ret != NXT_OK)) { + switch (f->error) { + + case NXT_ENOENT: + case NXT_ENOTDIR: + case NXT_ENAMETOOLONG: + level = NXT_LOG_ERR; + status = NXT_HTTP_NOT_FOUND; + break; + + case NXT_EACCES: + level = NXT_LOG_ERR; + status = NXT_HTTP_FORBIDDEN; + break; + + default: + level = NXT_LOG_ALERT; + status = NXT_HTTP_INTERNAL_SERVER_ERROR; + break; + } + + if (status != NXT_HTTP_NOT_FOUND) { + nxt_log(task, level, "open(\"%FN\") failed %E", f->name, f->error); + } + + nxt_http_request_error(task, r, status); + return NULL; + } + + ret = nxt_file_info(f, &fi); + if (nxt_slow_path(ret != NXT_OK)) { + goto fail; + } + + if (nxt_fast_path(nxt_is_file(&fi))) { + r->status = NXT_HTTP_OK; + r->resp.content_length_n = nxt_file_size(&fi); + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Last-Modified"); + + p = nxt_mp_nget(r->mem_pool, NXT_HTTP_DATE_LEN); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + nxt_localtime(nxt_file_mtime(&fi), &tm); + + field->value = p; + field->value_length = nxt_http_date(p, &tm) - p; + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "ETag"); + + alloc = NXT_TIME_T_HEXLEN + NXT_OFF_T_HEXLEN + 3; + + p = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + field->value = p; + field->value_length = nxt_sprintf(p, p + alloc, "\"%xT-%xO\"", + nxt_file_mtime(&fi), + nxt_file_size(&fi)) + - p; + + if (extension.start == NULL) { + nxt_http_static_extract_extension(r->path, &extension); + } + + rtcf = r->conf->socket_conf->router_conf; + + mtype = nxt_http_static_mtypes_hash_find(&rtcf->mtypes_hash, + &extension); + + if (mtype != NULL) { + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Content-Type"); + + field->value = mtype->start; + field->value_length = mtype->length; + } + + if (need_body && nxt_file_size(&fi) > 0) { + fb = nxt_mp_zget(r->mem_pool, NXT_BUF_FILE_SIZE); + if (nxt_slow_path(fb == NULL)) { + goto fail; + } + + fb->file = f; + fb->file_end = nxt_file_size(&fi); + + r->out = fb; + + body_handler = &nxt_http_static_body_handler; + + } else { + nxt_file_close(task, f); + body_handler = NULL; + } + + } else { + /* Not a file. */ + + nxt_file_close(task, f); + f = NULL; + + if (nxt_slow_path(!nxt_is_dir(&fi))) { + nxt_log(task, NXT_LOG_ERR, "\"%FN\" is not a regular file", + f->name); + nxt_http_request_error(task, r, NXT_HTTP_NOT_FOUND); + return NULL; + } + + r->status = NXT_HTTP_MOVED_PERMANENTLY; + r->resp.content_length_n = 0; + + field = nxt_list_zero_add(r->resp.fields); + if (nxt_slow_path(field == NULL)) { + goto fail; + } + + nxt_http_field_name_set(field, "Location"); + + encode = nxt_encode_uri(NULL, r->path->start, r->path->length); + alloc = r->path->length + encode * 2 + 1; + + if (r->args->length > 0) { + alloc += 1 + r->args->length; + } + + p = nxt_mp_nget(r->mem_pool, alloc); + if (nxt_slow_path(p == NULL)) { + goto fail; + } + + field->value = p; + field->value_length = alloc; + + if (encode > 0) { + p = (u_char *) nxt_encode_uri(p, r->path->start, r->path->length); + + } else { + p = nxt_cpymem(p, r->path->start, r->path->length); + } + + *p++ = '/'; + + if (r->args->length > 0) { + *p++ = '?'; + nxt_memcpy(p, r->args->start, r->args->length); + } + + body_handler = NULL; + } + + nxt_http_request_header_send(task, r, body_handler); + + r->state = &nxt_http_static_send_state; + return NULL; + +fail: + + nxt_http_request_error(task, r, NXT_HTTP_INTERNAL_SERVER_ERROR); + + if (f != NULL && f->fd != NXT_FILE_INVALID) { + nxt_file_close(task, f); + } + + return NULL; +} + + +static void +nxt_http_static_extract_extension(nxt_str_t *path, nxt_str_t *extension) +{ + u_char ch, *p, *end; + + end = path->start + path->length; + p = end; + + for ( ;; ) { + /* There's always '/' in the beginning of the request path. */ + + p--; + ch = *p; + + switch (ch) { + case '/': + p++; + /* Fall through. */ + case '.': + extension->length = end - p; + extension->start = p; + return; + } + } +} + + +static void +nxt_http_static_body_handler(nxt_task_t *task, void *obj, void *data) +{ + size_t alloc; + nxt_buf_t *fb, *b, **next, *out; + nxt_off_t rest; + nxt_int_t n; + nxt_work_queue_t *wq; + nxt_http_request_t *r; + + r = obj; + fb = r->out; + + rest = fb->file_end - fb->file_pos; + out = NULL; + next = &out; + n = 0; + + do { + alloc = nxt_min(rest, NXT_HTTP_STATIC_BUF_SIZE); + + b = nxt_buf_mem_alloc(r->mem_pool, alloc, 0); + if (nxt_slow_path(b == NULL)) { + goto fail; + } + + b->completion_handler = nxt_http_static_buf_completion; + b->parent = r; + + nxt_mp_retain(r->mem_pool); + + *next = b; + next = &b->next; + + rest -= alloc; + + } while (rest > 0 && ++n < NXT_HTTP_STATIC_BUF_COUNT); + + wq = &task->thread->engine->fast_work_queue; + + nxt_sendbuf_drain(task, wq, out); + return; + +fail: + + while (out != NULL) { + b = out; + out = b->next; + + nxt_mp_free(r->mem_pool, b); + nxt_mp_release(r->mem_pool); + } +} + + +static const nxt_http_request_state_t nxt_http_static_send_state + nxt_aligned(64) = +{ + .error_handler = nxt_http_request_error_handler, +}; + + +static void +nxt_http_static_buf_completion(nxt_task_t *task, void *obj, void *data) +{ + ssize_t n, size; + nxt_buf_t *b, *fb; + nxt_off_t rest; + nxt_http_request_t *r; + + b = obj; + r = data; + fb = r->out; + + if (nxt_slow_path(fb == NULL || r->error)) { + goto clean; + } + + rest = fb->file_end - fb->file_pos; + size = nxt_buf_mem_size(&b->mem); + + size = nxt_min(rest, (nxt_off_t) size); + + n = nxt_file_read(fb->file, b->mem.start, size, fb->file_pos); + + if (n != size) { + if (n >= 0) { + nxt_log(task, NXT_LOG_ERR, "file \"%FN\" has changed " + "while sending response to a client", fb->file->name); + } + + nxt_http_request_error_handler(task, r, r->proto.any); + goto clean; + } + + if (n == rest) { + nxt_file_close(task, fb->file); + r->out = NULL; + + b->next = nxt_http_buf_last(r); + + } else { + fb->file_pos += n; + b->next = NULL; + } + + b->mem.pos = b->mem.start; + b->mem.free = b->mem.pos + n; + + nxt_http_request_send(task, r, b); + return; + +clean: + + nxt_mp_free(r->mem_pool, b); + nxt_mp_release(r->mem_pool); + + if (fb != NULL) { + nxt_file_close(task, fb->file); + r->out = NULL; + } +} + + +nxt_int_t +nxt_http_static_mtypes_init(nxt_mp_t *mp, nxt_lvlhsh_t *hash) +{ + nxt_str_t *type, extension; + nxt_int_t ret; + nxt_uint_t i; + + static const struct { + nxt_str_t type; + const char *extension; + } default_types[] = { + + { nxt_string("text/html"), ".html" }, + { nxt_string("text/html"), ".htm" }, + { nxt_string("text/css"), ".css" }, + + { nxt_string("image/svg+xml"), ".svg" }, + { nxt_string("image/svg+xml"), ".svg" }, + { nxt_string("image/webp"), ".webp" }, + { nxt_string("image/png"), ".png" }, + { nxt_string("image/jpeg"), ".jpeg" }, + { nxt_string("image/jpeg"), ".jpg" }, + { nxt_string("image/gif"), ".gif" }, + { nxt_string("image/x-icon"), ".ico" }, + + { nxt_string("font/woff"), ".woff" }, + { nxt_string("font/woff2"), ".woff2" }, + { nxt_string("font/otf"), ".otf" }, + { nxt_string("font/ttf"), ".ttf" }, + + { nxt_string("text/plain"), ".txt" }, + { nxt_string("text/markdown"), ".md" }, + { nxt_string("text/x-rst"), ".rst" }, + + { nxt_string("application/javascript"), ".js" }, + { nxt_string("application/json"), ".json" }, + { nxt_string("application/xml"), ".xml" }, + { nxt_string("application/rss+xml"), ".rss" }, + { nxt_string("application/atom+xml"), ".atom" }, + { nxt_string("application/pdf"), ".pdf" }, + + { nxt_string("application/zip"), ".zip" }, + + { nxt_string("audio/mpeg"), ".mp3" }, + { nxt_string("audio/ogg"), ".ogg" }, + { nxt_string("audio/midi"), ".midi" }, + { nxt_string("audio/midi"), ".mid" }, + { nxt_string("audio/flac"), ".flac" }, + { nxt_string("audio/aac"), ".aac" }, + { nxt_string("audio/wav"), ".wav" }, + + { nxt_string("video/mpeg"), ".mpeg" }, + { nxt_string("video/mpeg"), ".mpg" }, + { nxt_string("video/mp4"), ".mp4" }, + { nxt_string("video/webm"), ".webm" }, + { nxt_string("video/x-msvideo"), ".avi" }, + + { nxt_string("application/octet-stream"), ".exe" }, + { nxt_string("application/octet-stream"), ".bin" }, + { nxt_string("application/octet-stream"), ".dll" }, + { nxt_string("application/octet-stream"), ".iso" }, + { nxt_string("application/octet-stream"), ".img" }, + { nxt_string("application/octet-stream"), ".msi" }, + + { nxt_string("application/octet-stream"), ".deb" }, + { nxt_string("application/octet-stream"), ".rpm" }, + }; + + for (i = 0; i < nxt_nitems(default_types); i++) { + type = (nxt_str_t *) &default_types[i].type; + + extension.start = (u_char *) default_types[i].extension; + extension.length = nxt_strlen(extension.start); + + ret = nxt_http_static_mtypes_hash_add(mp, hash, &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + } + + return NXT_OK; +} + + +static const nxt_lvlhsh_proto_t nxt_http_static_mtypes_hash_proto + nxt_aligned(64) = +{ + NXT_LVLHSH_DEFAULT, + nxt_http_static_mtypes_hash_test, + nxt_http_static_mtypes_hash_alloc, + nxt_http_static_mtypes_hash_free, +}; + + +typedef struct { + nxt_str_t extension; + nxt_str_t *type; +} nxt_http_static_mtype_t; + + +nxt_int_t +nxt_http_static_mtypes_hash_add(nxt_mp_t *mp, nxt_lvlhsh_t *hash, + nxt_str_t *extension, nxt_str_t *type) +{ + nxt_lvlhsh_query_t lhq; + nxt_http_static_mtype_t *mtype; + + mtype = nxt_mp_get(mp, sizeof(nxt_http_static_mtype_t)); + if (nxt_slow_path(mtype == NULL)) { + return NXT_ERROR; + } + + mtype->extension = *extension; + mtype->type = type; + + lhq.key = *extension; + lhq.key_hash = nxt_djb_hash_lowcase(lhq.key.start, lhq.key.length); + lhq.replace = 1; + lhq.value = mtype; + lhq.proto = &nxt_http_static_mtypes_hash_proto; + lhq.pool = mp; + + return nxt_lvlhsh_insert(hash, &lhq); +} + + +nxt_str_t * +nxt_http_static_mtypes_hash_find(nxt_lvlhsh_t *hash, nxt_str_t *extension) +{ + nxt_lvlhsh_query_t lhq; + nxt_http_static_mtype_t *mtype; + + lhq.key = *extension; + lhq.key_hash = nxt_djb_hash_lowcase(lhq.key.start, lhq.key.length); + lhq.proto = &nxt_http_static_mtypes_hash_proto; + + if (nxt_lvlhsh_find(hash, &lhq) == NXT_OK) { + mtype = lhq.value; + return mtype->type; + } + + return NULL; +} + + +static nxt_int_t +nxt_http_static_mtypes_hash_test(nxt_lvlhsh_query_t *lhq, void *data) +{ + nxt_http_static_mtype_t *mtype; + + mtype = data; + + return nxt_strcasestr_eq(&lhq->key, &mtype->extension) ? NXT_OK + : NXT_DECLINED; +} + + +static void * +nxt_http_static_mtypes_hash_alloc(void *data, size_t size) +{ + return nxt_mp_align(data, size, size); +} + + +static void +nxt_http_static_mtypes_hash_free(void *data, void *p) +{ + nxt_mp_free(data, p); +} diff --git a/src/nxt_java.c b/src/nxt_java.c index 3421d825..08e24595 100644 --- a/src/nxt_java.c +++ b/src/nxt_java.c @@ -13,6 +13,7 @@ #include <nxt_unit_field.h> #include <nxt_unit_request.h> #include <nxt_unit_response.h> +#include <nxt_unit_websocket.h> #include <java/nxt_jni.h> @@ -30,6 +31,8 @@ static nxt_int_t nxt_java_pre_init(nxt_task_t *task, nxt_common_app_conf_t *conf); static nxt_int_t nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf); static void nxt_java_request_handler(nxt_unit_request_info_t *req); +static void nxt_java_websocket_handler(nxt_unit_websocket_frame_t *ws); +static void nxt_java_close_handler(nxt_unit_request_info_t *req); static uint32_t compat[] = { NXT_VERNUM, NXT_DEBUG, @@ -347,6 +350,8 @@ nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf) nxt_unit_default_init(task, &java_init); java_init.callbacks.request_handler = nxt_java_request_handler; + java_init.callbacks.websocket_handler = nxt_java_websocket_handler; + java_init.callbacks.close_handler = nxt_java_close_handler; java_init.request_data_size = sizeof(nxt_java_request_data_t); java_init.data = &data; @@ -361,14 +366,14 @@ nxt_java_init(nxt_task_t *task, nxt_common_app_conf_t *conf) /* TODO report error */ } - nxt_unit_done(ctx); - nxt_java_stopContext(env, data.ctx); if ((*env)->ExceptionCheck(env)) { (*env)->ExceptionDescribe(env); } + nxt_unit_done(ctx); + (*jvm)->DestroyJavaVM(jvm); exit(0); @@ -454,8 +459,71 @@ nxt_java_request_handler(nxt_unit_request_info_t *req) data->buf = NULL; } + if (nxt_unit_response_is_websocket(req)) { + data->jreq = (*env)->NewGlobalRef(env, jreq); + data->jresp = (*env)->NewGlobalRef(env, jresp); + + } else { + nxt_unit_request_done(req, NXT_UNIT_OK); + } + (*env)->DeleteLocalRef(env, jresp); (*env)->DeleteLocalRef(env, jreq); +} + + +static void +nxt_java_websocket_handler(nxt_unit_websocket_frame_t *ws) +{ + void *b; + JNIEnv *env; + jobject jbuf; + nxt_java_data_t *java_data; + nxt_java_request_data_t *data; + + java_data = ws->req->unit->data; + env = java_data->env; + data = ws->req->data; + + b = malloc(ws->payload_len); + if (b != NULL) { + nxt_unit_websocket_read(ws, b, ws->payload_len); + + jbuf = (*env)->NewDirectByteBuffer(env, b, ws->payload_len); + if (jbuf != NULL) { + nxt_java_Request_websocket(env, data->jreq, jbuf, + ws->header->opcode, ws->header->fin); + + if ((*env)->ExceptionCheck(env)) { + (*env)->ExceptionDescribe(env); + (*env)->ExceptionClear(env); + } + + (*env)->DeleteLocalRef(env, jbuf); + } + + free(b); + } + + nxt_unit_websocket_done(ws); +} + + +static void +nxt_java_close_handler(nxt_unit_request_info_t *req) +{ + JNIEnv *env; + nxt_java_data_t *java_data; + nxt_java_request_data_t *data; + + java_data = req->unit->data; + env = java_data->env; + data = req->data; + + nxt_java_Request_close(env, data->jreq); + + (*env)->DeleteGlobalRef(env, data->jresp); + (*env)->DeleteGlobalRef(env, data->jreq); nxt_unit_request_done(req, NXT_UNIT_OK); } diff --git a/src/nxt_main.h b/src/nxt_main.h index 23c55002..0afebb96 100644 --- a/src/nxt_main.h +++ b/src/nxt_main.h @@ -57,6 +57,7 @@ typedef uint16_t nxt_port_id_t; #include <nxt_fiber.h> #include <nxt_thread.h> #include <nxt_process_type.h> +#include <nxt_capability.h> #include <nxt_process.h> #include <nxt_utf8.h> #include <nxt_file_name.h> diff --git a/src/nxt_main_process.c b/src/nxt_main_process.c index 40682eb9..44deb272 100644 --- a/src/nxt_main_process.c +++ b/src/nxt_main_process.c @@ -14,6 +14,10 @@ #include <nxt_cert.h> #endif +#ifdef NXT_LINUX +#include <linux/sched.h> +#endif + typedef struct { nxt_socket_t socket; @@ -68,6 +72,10 @@ static void nxt_main_port_conf_store_handler(nxt_task_t *task, static void nxt_main_port_access_log_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg); +static nxt_int_t nxt_init_set_isolation(nxt_task_t *task, + nxt_process_init_t *init, nxt_conf_value_t *isolation); +static nxt_int_t nxt_init_set_ns(nxt_task_t *task, + nxt_process_init_t *init, nxt_conf_value_t *ns); const nxt_sig_event_t nxt_main_process_signals[] = { nxt_event_signal(SIGHUP, nxt_main_process_signal_handler), @@ -134,6 +142,12 @@ static nxt_conf_map_t nxt_common_app_conf[] = { NXT_CONF_MAP_PTR, offsetof(nxt_common_app_conf_t, environment), }, + + { + nxt_string("isolation"), + NXT_CONF_MAP_PTR, + offsetof(nxt_common_app_conf_t, isolation), + } }; @@ -271,12 +285,11 @@ nxt_port_main_start_worker_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) nxt_int_t ret; nxt_buf_t *b; nxt_port_t *port; + nxt_runtime_t *rt; nxt_app_type_t idx; nxt_conf_value_t *conf; nxt_common_app_conf_t app_conf; - static nxt_str_t nobody = nxt_string("nobody"); - ret = NXT_ERROR; mp = nxt_mp_create(1024, 128, 256, 32); @@ -311,7 +324,10 @@ nxt_port_main_start_worker_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) goto failed; } - app_conf.user = nobody; + rt = task->thread->runtime; + + app_conf.user.start = (u_char*)rt->user_cred.user; + app_conf.user.length = nxt_strlen(rt->user_cred.user); ret = nxt_conf_map_object(mp, conf, nxt_common_app_conf, nxt_nitems(nxt_common_app_conf), &app_conf); @@ -458,6 +474,8 @@ nxt_main_start_controller_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_controller_start; init->name = "controller"; init->user_cred = &rt->user_cred; @@ -552,6 +570,8 @@ nxt_main_start_discovery_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_discovery_start; init->name = "discovery"; init->user_cred = &rt->user_cred; @@ -576,6 +596,8 @@ nxt_main_start_router_process(nxt_task_t *task, nxt_runtime_t *rt) return NXT_ERROR; } + nxt_memzero(init, sizeof(nxt_process_init_t)); + init->start = nxt_router_start; init->name = "router"; init->user_cred = &rt->user_cred; @@ -589,7 +611,6 @@ nxt_main_start_router_process(nxt_task_t *task, nxt_runtime_t *rt) return nxt_main_create_worker_process(task, rt, init); } - static nxt_int_t nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, nxt_common_app_conf_t *app_conf, uint32_t stream) @@ -597,41 +618,72 @@ nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, char *user, *group; u_char *title, *last, *end; size_t size; + nxt_int_t ret; nxt_process_init_t *init; size = sizeof(nxt_process_init_t) - + sizeof(nxt_user_cred_t) - + app_conf->user.length + 1 - + app_conf->group.length + 1 - + app_conf->name.length + sizeof("\"\" application"); + + app_conf->name.length + + sizeof("\"\" application"); + + if (rt->capabilities.setid) { + size += sizeof(nxt_user_cred_t) + + app_conf->user.length + 1 + + app_conf->group.length + 1; + } init = nxt_malloc(size); if (nxt_slow_path(init == NULL)) { return NXT_ERROR; } - init->user_cred = nxt_pointer_to(init, sizeof(nxt_process_init_t)); - user = nxt_pointer_to(init->user_cred, sizeof(nxt_user_cred_t)); + nxt_memzero(init, sizeof(nxt_process_init_t)); - nxt_memcpy(user, app_conf->user.start, app_conf->user.length); - last = nxt_pointer_to(user, app_conf->user.length); - *last++ = '\0'; + if (rt->capabilities.setid) { + init->user_cred = nxt_pointer_to(init, sizeof(nxt_process_init_t)); + user = nxt_pointer_to(init->user_cred, sizeof(nxt_user_cred_t)); - init->user_cred->user = user; + nxt_memcpy(user, app_conf->user.start, app_conf->user.length); + last = nxt_pointer_to(user, app_conf->user.length); + *last++ = '\0'; - if (app_conf->group.start != NULL) { - group = (char *) last; + init->user_cred->user = user; - nxt_memcpy(group, app_conf->group.start, app_conf->group.length); - last = nxt_pointer_to(group, app_conf->group.length); - *last++ = '\0'; + if (app_conf->group.start != NULL) { + group = (char *) last; + + nxt_memcpy(group, app_conf->group.start, app_conf->group.length); + last = nxt_pointer_to(group, app_conf->group.length); + *last++ = '\0'; + + } else { + group = NULL; + } + + ret = nxt_user_cred_get(task, init->user_cred, group); + if (ret != NXT_OK) { + return NXT_ERROR; + } } else { - group = NULL; - } + if (!nxt_str_eq(&app_conf->user, (u_char *) rt->user_cred.user, + nxt_strlen(rt->user_cred.user))) + { + nxt_alert(task, "cannot set user \"%V\" for app \"%V\": " + "missing capabilities", &app_conf->user, &app_conf->name); + return NXT_ERROR; + } - if (nxt_user_cred_get(task, init->user_cred, group) != NXT_OK) { - return NXT_ERROR; + if (app_conf->group.length > 0 + && !nxt_str_eq(&app_conf->group, (u_char *) rt->group, + nxt_strlen(rt->group))) + { + nxt_alert(task, "cannot set group \"%V\" for app \"%V\": " + "missing capabilities", &app_conf->group, + &app_conf->name); + return NXT_ERROR; + } + + last = nxt_pointer_to(init, sizeof(nxt_process_init_t)); } title = last; @@ -648,6 +700,11 @@ nxt_main_start_worker_process(nxt_task_t *task, nxt_runtime_t *rt, init->stream = stream; init->restart = NULL; + ret = nxt_init_set_isolation(task, init, app_conf->isolation); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + return nxt_main_create_worker_process(task, rt, init); } @@ -690,15 +747,18 @@ nxt_main_create_worker_process(nxt_task_t *task, nxt_runtime_t *rt, pid = nxt_process_create(task, process); - nxt_port_use(task, port, -1); - switch (pid) { case -1: + nxt_port_close(task, port); + nxt_port_use(task, port, -1); + return NXT_ERROR; case 0: /* A worker process, return to the event engine work queue loop. */ + nxt_port_use(task, port, -1); + return NXT_AGAIN; default: @@ -707,6 +767,8 @@ nxt_main_create_worker_process(nxt_task_t *task, nxt_runtime_t *rt, nxt_port_read_close(port); nxt_port_write_enable(task, port); + nxt_port_use(task, port, -1); + return NXT_OK; } } @@ -1241,7 +1303,7 @@ nxt_main_port_modules_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) nxt_conf_value_t *conf, *root, *value; nxt_app_lang_module_t *lang; - static nxt_str_t root_path = nxt_string("/"); + static nxt_str_t root_path = nxt_string("/"); rt = task->thread->runtime; @@ -1433,3 +1495,105 @@ nxt_main_port_access_log_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg) msg->port_msg.stream, 0, NULL); } } + + +static nxt_int_t +nxt_init_set_isolation(nxt_task_t *task, nxt_process_init_t *init, + nxt_conf_value_t *isolation) +{ + nxt_int_t ret; + nxt_conf_value_t *object; + + static nxt_str_t nsname = nxt_string("namespaces"); + static nxt_str_t uidname = nxt_string("uidmap"); + static nxt_str_t gidname = nxt_string("gidmap"); + + if (isolation == NULL) { + return NXT_OK; + } + + object = nxt_conf_get_object_member(isolation, &nsname, NULL); + if (object != NULL) { + ret = nxt_init_set_ns(task, init, object); + if (ret != NXT_OK) { + return ret; + } + } + + object = nxt_conf_get_object_member(isolation, &uidname, NULL); + if (object != NULL) { + init->isolation.clone.uidmap = object; + } + + object = nxt_conf_get_object_member(isolation, &gidname, NULL); + if (object != NULL) { + init->isolation.clone.gidmap = object; + } + + return NXT_OK; +} + + +static nxt_int_t +nxt_init_set_ns(nxt_task_t *task, nxt_process_init_t *init, nxt_conf_value_t *namespaces) +{ + uint32_t index; + nxt_str_t name; + nxt_int_t flag; + nxt_conf_value_t *value; + + index = 0; + + while ((value = nxt_conf_next_object_member(namespaces, &name, &index)) != NULL) { + flag = 0; + +#if (NXT_HAVE_CLONE_NEWUSER) + if (nxt_str_eq(&name, "credential", 10)) { + flag = CLONE_NEWUSER; + } +#endif + +#if (NXT_HAVE_CLONE_NEWPID) + if (nxt_str_eq(&name, "pid", 3)) { + flag = CLONE_NEWPID; + } +#endif + +#if (NXT_HAVE_CLONE_NEWNET) + if (nxt_str_eq(&name, "network", 7)) { + flag = CLONE_NEWNET; + } +#endif + +#if (NXT_HAVE_CLONE_NEWUTS) + if (nxt_str_eq(&name, "uname", 5)) { + flag = CLONE_NEWUTS; + } +#endif + +#if (NXT_HAVE_CLONE_NEWNS) + if (nxt_str_eq(&name, "mount", 5)) { + flag = CLONE_NEWNS; + } +#endif + +#if (NXT_HAVE_CLONE_NEWCGROUP) + if (nxt_str_eq(&name, "cgroup", 6)) { + flag = CLONE_NEWCGROUP; + } +#endif + + if (!flag) { + nxt_alert(task, "unknown namespace flag: \"%V\"", &name); + return NXT_ERROR; + } + + if (nxt_conf_get_integer(value) == 0) { + continue; /* process shares everything by default */ + } + + init->isolation.clone.flags |= flag; + } + + return NXT_OK; +} diff --git a/src/nxt_port.c b/src/nxt_port.c index cef65cab..9029353a 100644 --- a/src/nxt_port.c +++ b/src/nxt_port.c @@ -546,7 +546,7 @@ nxt_port_use(nxt_task_t *task, nxt_port_t *port, int i) if (i < 0 && c == -i) { - if (task->thread->engine == port->engine) { + if (port->engine == NULL || task->thread->engine == port->engine) { nxt_port_release(task, port); return; diff --git a/src/nxt_process.c b/src/nxt_process.c index c4aef21c..638765a4 100644 --- a/src/nxt_process.c +++ b/src/nxt_process.c @@ -7,10 +7,16 @@ #include <nxt_main.h> #include <nxt_main_process.h> +#if (NXT_HAVE_CLONE) +#include <nxt_clone.h> +#endif + +#include <signal.h> static void nxt_process_start(nxt_task_t *task, nxt_process_t *process); static nxt_int_t nxt_user_groups_get(nxt_task_t *task, nxt_user_cred_t *uc); - +static nxt_int_t nxt_process_worker_setup(nxt_task_t *task, + nxt_process_t *process, int parentfd); /* A cached process pid. */ nxt_pid_t nxt_pid; @@ -34,84 +40,217 @@ nxt_bool_t nxt_proc_remove_notify_matrix[NXT_PROCESS_MAX][NXT_PROCESS_MAX] = { { 0, 0, 0, 1, 0 }, }; -nxt_pid_t -nxt_process_create(nxt_task_t *task, nxt_process_t *process) -{ - nxt_pid_t pid; + +static nxt_int_t +nxt_process_worker_setup(nxt_task_t *task, nxt_process_t *process, int parentfd) { + pid_t rpid, pid; + ssize_t n; + nxt_int_t parent_status; nxt_process_t *p; nxt_runtime_t *rt; + nxt_process_init_t *init; nxt_process_type_t ptype; - rt = task->thread->runtime; + pid = getpid(); + rpid = 0; + rt = task->thread->runtime; + init = process->init; - pid = fork(); + /* Setup the worker process. */ - switch (pid) { + n = read(parentfd, &rpid, sizeof(rpid)); + if (nxt_slow_path(n == -1 || n != sizeof(rpid))) { + nxt_alert(task, "failed to read real pid"); + return NXT_ERROR; + } - case -1: - nxt_alert(task, "fork() failed while creating \"%s\" %E", - process->init->name, nxt_errno); - break; + if (nxt_slow_path(rpid == 0)) { + nxt_alert(task, "failed to get real pid from parent"); + return NXT_ERROR; + } - case 0: - /* A child. */ - nxt_pid = getpid(); + nxt_pid = rpid; + + /* Clean inherited cached thread tid. */ + task->thread->tid = 0; + + process->pid = nxt_pid; + + if (nxt_pid != pid) { + nxt_debug(task, "app \"%s\" real pid %d", init->name, nxt_pid); + nxt_debug(task, "app \"%s\" isolated pid: %d", init->name, pid); + } - /* Clean inherited cached thread tid. */ - task->thread->tid = 0; + n = read(parentfd, &parent_status, sizeof(parent_status)); + if (nxt_slow_path(n == -1 || n != sizeof(parent_status))) { + nxt_alert(task, "failed to read parent status"); + return NXT_ERROR; + } - process->pid = nxt_pid; + if (nxt_slow_path(close(parentfd) == -1)) { + nxt_alert(task, "failed to close reader pipe fd"); + return NXT_ERROR; + } - ptype = process->init->type; + if (nxt_slow_path(parent_status != NXT_OK)) { + return parent_status; + } - nxt_port_reset_next_id(); + ptype = init->type; - nxt_event_engine_thread_adopt(task->thread->engine); + nxt_port_reset_next_id(); - /* Remove not ready processes */ - nxt_runtime_process_each(rt, p) { + nxt_event_engine_thread_adopt(task->thread->engine); - if (nxt_proc_conn_matrix[ptype][nxt_process_type(p)] == 0) { - nxt_debug(task, "remove not required process %PI", p->pid); + /* Remove not ready processes. */ + nxt_runtime_process_each(rt, p) { - nxt_process_close_ports(task, p); + if (nxt_proc_conn_matrix[ptype][nxt_process_type(p)] == 0) { + nxt_debug(task, "remove not required process %PI", p->pid); - continue; - } + nxt_process_close_ports(task, p); - if (!p->ready) { - nxt_debug(task, "remove not ready process %PI", p->pid); + continue; + } - nxt_process_close_ports(task, p); + if (!p->ready) { + nxt_debug(task, "remove not ready process %PI", p->pid); - continue; - } + nxt_process_close_ports(task, p); - nxt_port_mmaps_destroy(&p->incoming, 0); - nxt_port_mmaps_destroy(&p->outgoing, 0); + continue; + } - } nxt_runtime_process_loop; + nxt_port_mmaps_destroy(&p->incoming, 0); + nxt_port_mmaps_destroy(&p->outgoing, 0); - nxt_runtime_process_add(task, process); + } nxt_runtime_process_loop; - nxt_process_start(task, process); + nxt_runtime_process_add(task, process); - process->ready = 1; + nxt_process_start(task, process); - break; + process->ready = 1; - default: - /* A parent. */ - nxt_debug(task, "fork(\"%s\"): %PI", process->init->name, pid); + return NXT_OK; +} - process->pid = pid; - nxt_runtime_process_add(task, process); +nxt_pid_t +nxt_process_create(nxt_task_t *task, nxt_process_t *process) +{ + int pipefd[2]; + nxt_int_t ret; + nxt_pid_t pid; + nxt_process_init_t *init; - break; + if (nxt_slow_path(pipe(pipefd) == -1)) { + nxt_alert(task, "failed to create process pipe for passing rpid"); + return -1; + } + + init = process->init; + +#if (NXT_HAVE_CLONE) + pid = nxt_clone(SIGCHLD|init->isolation.clone.flags); +#else + pid = fork(); +#endif + + if (nxt_slow_path(pid < 0)) { +#if (NXT_HAVE_CLONE) + nxt_alert(task, "clone() failed while creating \"%s\" %E", + init->name, nxt_errno); +#else + nxt_alert(task, "fork() failed while creating \"%s\" %E", + init->name, nxt_errno); +#endif + + return pid; + } + + if (pid == 0) { + /* Child. */ + + if (nxt_slow_path(close(pipefd[1]) == -1)) { + nxt_alert(task, "failed to close writer pipe fd"); + return NXT_ERROR; + } + + ret = nxt_process_worker_setup(task, process, pipefd[0]); + if (nxt_slow_path(ret != NXT_OK)) { + exit(1); + } + + /* + * Explicitly return 0 to notice the caller function this is the child. + * The caller must return to the event engine work queue loop. + */ + return 0; + } + + /* Parent. */ + + if (nxt_slow_path(close(pipefd[0]) != 0)) { + nxt_alert(task, "failed to close pipe: %E", nxt_errno); + } + + /* + * At this point, the child process is blocked reading the + * pipe fd to get its real pid (rpid). + * + * If anything goes wrong now, we need to terminate the child + * process by sending a NXT_ERROR in the pipe. + */ + +#if (NXT_HAVE_CLONE) + nxt_debug(task, "clone(\"%s\"): %PI", init->name, pid); +#else + nxt_debug(task, "fork(\"%s\"): %PI", init->name, pid); +#endif + + if (nxt_slow_path(write(pipefd[1], &pid, sizeof(pid)) == -1)) { + nxt_alert(task, "failed to write real pid"); + goto fail_cleanup; + } + +#if (NXT_HAVE_CLONE_NEWUSER) + if ((init->isolation.clone.flags & CLONE_NEWUSER) == CLONE_NEWUSER) { + ret = nxt_clone_proc_map(task, pid, &init->isolation.clone); + if (nxt_slow_path(ret != NXT_OK)) { + goto fail_cleanup; + } + } +#endif + + ret = NXT_OK; + + if (nxt_slow_path(write(pipefd[1], &ret, sizeof(ret)) == -1)) { + nxt_alert(task, "failed to write status"); + goto fail_cleanup; } + process->pid = pid; + + nxt_runtime_process_add(task, process); + return pid; + +fail_cleanup: + + ret = NXT_ERROR; + + if (nxt_slow_path(write(pipefd[1], &ret, sizeof(ret)) == -1)) { + nxt_alert(task, "failed to write status"); + } + + if (nxt_slow_path(close(pipefd[1]) != 0)) { + nxt_alert(task, "failed to close pipe: %E", nxt_errno); + } + + waitpid(pid, NULL, 0); + + return -1; } @@ -133,22 +272,17 @@ nxt_process_start(nxt_task_t *task, nxt_process_t *process) nxt_process_title(task, "unit: %s", init->name); thread = task->thread; + rt = thread->runtime; nxt_random_init(&thread->random); - if (init->user_cred != NULL) { - /* - * Changing user credentials requires either root privileges - * or CAP_SETUID and CAP_SETGID capabilities on Linux. - */ + if (rt->capabilities.setid && init->user_cred != NULL) { ret = nxt_user_cred_set(task, init->user_cred); if (ret != NXT_OK) { goto fail; } } - rt = thread->runtime; - rt->type = init->type; engine = thread->engine; @@ -592,15 +726,8 @@ nxt_user_cred_set(nxt_task_t *task, nxt_user_cred_t *uc) uc->user, (uint64_t) uc->uid, (uint64_t) uc->base_gid); if (setgid(uc->base_gid) != 0) { - if (nxt_errno == NXT_EPERM) { - nxt_log(task, NXT_LOG_NOTICE, "setgid(%d) failed %E, ignored", - uc->base_gid, nxt_errno); - return NXT_OK; - - } else { - nxt_alert(task, "setgid(%d) failed %E", uc->base_gid, nxt_errno); - return NXT_ERROR; - } + nxt_alert(task, "setgid(%d) failed %E", uc->base_gid, nxt_errno); + return NXT_ERROR; } if (uc->gids != NULL) { diff --git a/src/nxt_process.h b/src/nxt_process.h index c6e19f97..df9ca038 100644 --- a/src/nxt_process.h +++ b/src/nxt_process.h @@ -7,6 +7,8 @@ #ifndef _NXT_PROCESS_H_INCLUDED_ #define _NXT_PROCESS_H_INCLUDED_ +#include <nxt_conf.h> + typedef pid_t nxt_pid_t; typedef uid_t nxt_uid_t; @@ -21,26 +23,35 @@ typedef struct { nxt_gid_t *gids; } nxt_user_cred_t; +typedef struct { + nxt_int_t flags; + nxt_conf_value_t *uidmap; + nxt_conf_value_t *gidmap; +} nxt_process_clone_t; + typedef struct nxt_process_init_s nxt_process_init_t; typedef nxt_int_t (*nxt_process_start_t)(nxt_task_t *task, void *data); typedef nxt_int_t (*nxt_process_restart_t)(nxt_task_t *task, nxt_runtime_t *rt, nxt_process_init_t *init); - struct nxt_process_init_s { - nxt_process_start_t start; - const char *name; - nxt_user_cred_t *user_cred; + nxt_process_start_t start; + const char *name; + nxt_user_cred_t *user_cred; + + nxt_port_handlers_t *port_handlers; + const nxt_sig_event_t *signals; - nxt_port_handlers_t *port_handlers; - const nxt_sig_event_t *signals; + nxt_process_type_t type; - nxt_process_type_t type; + void *data; + uint32_t stream; - void *data; - uint32_t stream; + nxt_process_restart_t restart; - nxt_process_restart_t restart; + union { + nxt_process_clone_t clone; + } isolation; }; diff --git a/src/nxt_router.c b/src/nxt_router.c index b87f588f..28781600 100644 --- a/src/nxt_router.c +++ b/src/nxt_router.c @@ -122,6 +122,8 @@ static void nxt_router_conf_send(nxt_task_t *task, static nxt_int_t nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, u_char *start, u_char *end); +static nxt_int_t nxt_router_conf_process_static(nxt_task_t *task, + nxt_router_conf_t *rtcf, nxt_conf_value_t *conf); static nxt_app_t *nxt_router_app_find(nxt_queue_t *queue, nxt_str_t *name); static void nxt_router_listen_socket_rpc_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_socket_conf_t *skcf); @@ -247,7 +249,7 @@ static nxt_int_t nxt_router_http_request_done(nxt_task_t *task, static void nxt_router_http_request_release(nxt_task_t *task, void *obj, void *data); -const nxt_http_request_state_t nxt_http_websocket; +extern const nxt_http_request_state_t nxt_http_websocket; static nxt_router_t *nxt_router; @@ -1399,7 +1401,7 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, nxt_conf_value_t *conf, *http, *value, *websocket; nxt_conf_value_t *applications, *application; nxt_conf_value_t *listeners, *listener; - nxt_conf_value_t *routes_conf; + nxt_conf_value_t *routes_conf, *static_conf; nxt_socket_conf_t *skcf; nxt_http_routes_t *routes; nxt_event_engine_t *engine; @@ -1419,6 +1421,7 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, #if (NXT_TLS) static nxt_str_t certificate_path = nxt_string("/tls/certificate"); #endif + static nxt_str_t static_path = nxt_string("/settings/http/static"); static nxt_str_t websocket_path = nxt_string("/settings/http/websocket"); conf = nxt_conf_json_parse(tmcf->mem_pool, start, end, NULL); @@ -1440,6 +1443,13 @@ nxt_router_conf_create(nxt_task_t *task, nxt_router_temp_conf_t *tmcf, tmcf->router_conf->threads = nxt_ncpu; } + static_conf = nxt_conf_get_path(conf, &static_path); + + ret = nxt_router_conf_process_static(task, tmcf->router_conf, static_conf); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + router = tmcf->router_conf->router; applications = nxt_conf_get_path(conf, &applications_path); @@ -1788,6 +1798,87 @@ fail: } +static nxt_int_t +nxt_router_conf_process_static(nxt_task_t *task, nxt_router_conf_t *rtcf, + nxt_conf_value_t *conf) +{ + uint32_t next, i; + nxt_mp_t *mp; + nxt_str_t *type, extension, str; + nxt_int_t ret; + nxt_uint_t exts; + nxt_conf_value_t *mtypes_conf, *ext_conf, *value; + + static nxt_str_t mtypes_path = nxt_string("/mime_types"); + + mp = rtcf->mem_pool; + + ret = nxt_http_static_mtypes_init(mp, &rtcf->mtypes_hash); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + if (conf == NULL) { + return NXT_OK; + } + + mtypes_conf = nxt_conf_get_path(conf, &mtypes_path); + + if (mtypes_conf != NULL) { + next = 0; + + for ( ;; ) { + ext_conf = nxt_conf_next_object_member(mtypes_conf, &str, &next); + + if (ext_conf == NULL) { + break; + } + + type = nxt_str_dup(mp, NULL, &str); + if (nxt_slow_path(type == NULL)) { + return NXT_ERROR; + } + + if (nxt_conf_type(ext_conf) == NXT_CONF_STRING) { + nxt_conf_get_string(ext_conf, &str); + + if (nxt_slow_path(nxt_str_dup(mp, &extension, &str) == NULL)) { + return NXT_ERROR; + } + + ret = nxt_http_static_mtypes_hash_add(mp, &rtcf->mtypes_hash, + &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + + continue; + } + + exts = nxt_conf_array_elements_count(ext_conf); + + for (i = 0; i < exts; i++) { + value = nxt_conf_get_array_element(ext_conf, i); + + nxt_conf_get_string(value, &str); + + if (nxt_slow_path(nxt_str_dup(mp, &extension, &str) == NULL)) { + return NXT_ERROR; + } + + ret = nxt_http_static_mtypes_hash_add(mp, &rtcf->mtypes_hash, + &extension, type); + if (nxt_slow_path(ret != NXT_OK)) { + return NXT_ERROR; + } + } + } + } + + return NXT_OK; +} + + static nxt_app_t * nxt_router_app_find(nxt_queue_t *queue, nxt_str_t *name) { @@ -3531,7 +3622,7 @@ nxt_router_response_ready_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg, nxt_buf_chain_add(&r->out, b); } - nxt_http_request_header_send(task, r); + nxt_http_request_header_send(task, r, nxt_http_request_send_body); if (r->websocket_handshake && r->status == NXT_HTTP_SWITCHING_PROTOCOLS) @@ -3575,11 +3666,6 @@ nxt_router_response_ready_handler(nxt_task_t *task, nxt_port_recv_msg_t *msg, } else { r->state = &nxt_http_request_send_state; } - - if (r->out) { - nxt_work_queue_add(&task->thread->engine->fast_work_queue, - nxt_http_request_send_body, task, r, NULL); - } } return; diff --git a/src/nxt_router.h b/src/nxt_router.h index b55a4de3..ec18ff48 100644 --- a/src/nxt_router.h +++ b/src/nxt_router.h @@ -38,9 +38,13 @@ typedef struct { typedef struct { uint32_t count; uint32_t threads; + + nxt_mp_t *mem_pool; + nxt_router_t *router; nxt_http_routes_t *routes; - nxt_mp_t *mem_pool; + + nxt_lvlhsh_t mtypes_hash; nxt_router_access_log_t *access_log; } nxt_router_conf_t; diff --git a/src/nxt_runtime.c b/src/nxt_runtime.c index 06478f72..de41ba4d 100644 --- a/src/nxt_runtime.c +++ b/src/nxt_runtime.c @@ -692,14 +692,26 @@ nxt_runtime_conf_init(nxt_task_t *task, nxt_runtime_t *rt) rt->state = NXT_STATE; rt->control = NXT_CONTROL_SOCK; + nxt_memzero(&rt->capabilities, sizeof(nxt_capabilities_t)); + if (nxt_runtime_conf_read_cmd(task, rt) != NXT_OK) { return NXT_ERROR; } - if (nxt_user_cred_get(task, &rt->user_cred, rt->group) != NXT_OK) { + if (nxt_capability_set(task, &rt->capabilities) != NXT_OK) { return NXT_ERROR; } + if (rt->capabilities.setid) { + if (nxt_user_cred_get(task, &rt->user_cred, rt->group) != NXT_OK) { + return NXT_ERROR; + } + + } else { + nxt_log(task, NXT_LOG_WARN, "Unit is running unprivileged, then it " + "cannot use arbitrary user and group."); + } + /* An engine's parameters. */ interface = nxt_service_get(rt->services, "engine", rt->engine); diff --git a/src/nxt_runtime.h b/src/nxt_runtime.h index 496ae478..0791f8e7 100644 --- a/src/nxt_runtime.h +++ b/src/nxt_runtime.h @@ -59,6 +59,7 @@ struct nxt_runtime_s { uint32_t engine_connections; uint32_t auxiliary_threads; nxt_user_cred_t user_cred; + nxt_capabilities_t capabilities; const char *group; const char *pid; const char *log; diff --git a/src/nxt_sprintf.c b/src/nxt_sprintf.c index 240f47ef..9557b327 100644 --- a/src/nxt_sprintf.c +++ b/src/nxt_sprintf.c @@ -12,8 +12,8 @@ /* * Supported formats: * - * %[0][width][x][X]O nxt_off_t - * %[0][width]T nxt_time_t + * %[0][width][x|X]O nxt_off_t + * %[0][width][x|X]T nxt_time_t * %[0][width][u][x|X]z ssize_t/size_t * %[0][width][u][x|X]d int/u_int * %[0][width][u][x|X]l long diff --git a/src/nxt_string.c b/src/nxt_string.c index 7d8c1ce3..b89e9555 100644 --- a/src/nxt_string.c +++ b/src/nxt_string.c @@ -451,3 +451,125 @@ nxt_strvers_match(u_char *version, u_char *prefix, size_t length) return 0; } + + +u_char * +nxt_decode_uri(u_char *dst, u_char *src, size_t length) +{ + u_char *end, ch; + uint8_t d0, d1; + + static const uint8_t hex[256] + nxt_aligned(32) = + { + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 16, 16, 16, 16, 16, 16, + 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 10, 11, 12, 13, 14, 15, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, + }; + + nxt_prefetch(&hex['0']); + + end = src + length; + + while (src < end) { + ch = *src++; + + if (ch == '%') { + if (nxt_slow_path(end - src < 2)) { + return NULL; + } + + d0 = hex[*src++]; + d1 = hex[*src++]; + + if (nxt_slow_path((d0 | d1) >= 16)) { + return NULL; + } + + ch = (d0 << 4) + d1; + } + + *dst++ = ch; + } + + return dst; +} + + +uintptr_t +nxt_encode_uri(u_char *dst, u_char *src, size_t length) +{ + u_char *end; + nxt_uint_t n; + + static const u_char hex[16] = "0123456789ABCDEF"; + + /* " ", "#", "%", "?", %00-%1F, %7F-%FF */ + + static const uint32_t escape[] = { + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + + /* ?>=< ;:98 7654 3210 /.-, +*)( '&%$ #"! */ + 0x80000029, /* 1000 0000 0000 0000 0000 0000 0010 1001 */ + + /* _^]\ [ZYX WVUT SRQP ONML KJIH GFED CBA@ */ + 0x00000000, /* 0000 0000 0000 0000 0000 0000 0000 0000 */ + + /* ~}| {zyx wvut srqp onml kjih gfed cba` */ + 0x80000000, /* 1000 0000 0000 0000 0000 0000 0000 0000 */ + + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff, /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + 0xffffffff /* 1111 1111 1111 1111 1111 1111 1111 1111 */ + }; + + end = src + length; + + if (dst == NULL) { + + /* Find the number of the characters to be escaped. */ + + n = 0; + + while (src < end) { + + if (escape[*src >> 5] & (1U << (*src & 0x1f))) { + n++; + } + + src++; + } + + return (uintptr_t) n; + } + + while (src < end) { + + if (escape[*src >> 5] & (1U << (*src & 0x1f))) { + *dst++ = '%'; + *dst++ = hex[*src >> 4]; + *dst++ = hex[*src & 0xf]; + + } else { + *dst++ = *src; + } + + src++; + } + + return (uintptr_t) dst; +} diff --git a/src/nxt_string.h b/src/nxt_string.h index 22a63a17..8d7b3b73 100644 --- a/src/nxt_string.h +++ b/src/nxt_string.h @@ -168,5 +168,8 @@ NXT_EXPORT nxt_int_t nxt_strverscmp(const u_char *s1, const u_char *s2); NXT_EXPORT nxt_bool_t nxt_strvers_match(u_char *version, u_char *prefix, size_t length); +NXT_EXPORT u_char *nxt_decode_uri(u_char *dst, u_char *src, size_t length); +NXT_EXPORT uintptr_t nxt_encode_uri(u_char *dst, u_char *src, size_t length); + #endif /* _NXT_STRING_H_INCLUDED_ */ diff --git a/src/nxt_unit.c b/src/nxt_unit.c index 28a0de20..9ccd1fd9 100644 --- a/src/nxt_unit.c +++ b/src/nxt_unit.c @@ -31,7 +31,7 @@ typedef struct nxt_unit_request_info_impl_s nxt_unit_request_info_impl_t; typedef struct nxt_unit_websocket_frame_impl_s nxt_unit_websocket_frame_impl_t; static nxt_unit_impl_t *nxt_unit_create(nxt_unit_init_t *init); -static void nxt_unit_ctx_init(nxt_unit_impl_t *lib, +static int nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, void *data); nxt_inline void nxt_unit_mmap_buf_insert(nxt_unit_mmap_buf_t **head, nxt_unit_mmap_buf_t *mmap_buf); @@ -204,6 +204,8 @@ struct nxt_unit_websocket_frame_impl_s { struct nxt_unit_ctx_impl_s { nxt_unit_ctx_t ctx; + pthread_mutex_t mutex; + nxt_unit_port_id_t read_port_id; int read_port_fd; @@ -331,6 +333,7 @@ nxt_unit_init(nxt_unit_init_t *init) } } + lib->pid = read_port.id.pid; ctx = &lib->main_ctx.ctx; rc = lib->callbacks.add_port(ctx, &ready_port); @@ -396,13 +399,15 @@ nxt_unit_create(nxt_unit_init_t *init) lib->processes.slot = NULL; lib->ports.slot = NULL; - lib->pid = getpid(); lib->log_fd = STDERR_FILENO; lib->online = 1; nxt_queue_init(&lib->contexts); - nxt_unit_ctx_init(lib, &lib->main_ctx, init->ctx_data); + rc = nxt_unit_ctx_init(lib, &lib->main_ctx, init->ctx_data); + if (nxt_slow_path(rc != NXT_UNIT_OK)) { + goto fail; + } cb = &lib->callbacks; @@ -446,15 +451,24 @@ fail: } -static void +static int nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, void *data) { + int rc; + ctx_impl->ctx.data = data; ctx_impl->ctx.unit = &lib->unit; nxt_queue_insert_tail(&lib->contexts, &ctx_impl->link); + rc = pthread_mutex_init(&ctx_impl->mutex, NULL); + if (nxt_slow_path(rc != 0)) { + nxt_unit_alert(NULL, "failed to initialize mutex (%d)", rc); + + return NXT_UNIT_ERROR; + } + nxt_queue_init(&ctx_impl->free_req); nxt_queue_init(&ctx_impl->free_ws); nxt_queue_init(&ctx_impl->active_req); @@ -470,6 +484,8 @@ nxt_unit_ctx_init(nxt_unit_impl_t *lib, nxt_unit_ctx_impl_t *ctx_impl, ctx_impl->read_port_fd = -1; ctx_impl->requests.slot = 0; + + return NXT_UNIT_OK; } @@ -962,6 +978,11 @@ nxt_unit_process_websocket(nxt_unit_ctx_t *ctx, nxt_unit_recv_msg_t *recv_msg) } else { b = nxt_unit_mmap_buf_get(ctx); if (nxt_slow_path(b == NULL)) { + nxt_unit_alert(ctx, "#%"PRIu32": failed to allocate buf", + req_impl->stream); + + nxt_unit_websocket_frame_release(&ws_impl->ws); + return NXT_UNIT_ERROR; } @@ -1029,18 +1050,22 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) lib = nxt_container_of(ctx->unit, nxt_unit_impl_t, unit); + pthread_mutex_lock(&ctx_impl->mutex); + if (nxt_queue_is_empty(&ctx_impl->free_req)) { + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl = malloc(sizeof(nxt_unit_request_info_impl_t) + lib->request_data_size); if (nxt_slow_path(req_impl == NULL)) { - nxt_unit_warn(ctx, "request info allocation failed"); - return NULL; } req_impl->req.unit = ctx->unit; req_impl->req.ctx = ctx; + pthread_mutex_lock(&ctx_impl->mutex); + } else { lnk = nxt_queue_first(&ctx_impl->free_req); nxt_queue_remove(lnk); @@ -1050,6 +1075,8 @@ nxt_unit_request_info_get(nxt_unit_ctx_t *ctx) nxt_queue_insert_tail(&ctx_impl->active_req, &req_impl->link); + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl->req.data = lib->request_data_size ? req_impl->extra_data : NULL; return req_impl; @@ -1068,12 +1095,6 @@ nxt_unit_request_info_release(nxt_unit_request_info_t *req) req->response = NULL; req->response_buf = NULL; - if (req_impl->process != NULL) { - nxt_unit_process_use(req->ctx, req_impl->process, -1); - - req_impl->process = NULL; - } - if (req_impl->websocket) { nxt_unit_request_hash_find(&ctx_impl->requests, req_impl->stream, 1); @@ -1088,10 +1109,24 @@ nxt_unit_request_info_release(nxt_unit_request_info_t *req) nxt_unit_mmap_buf_free(req_impl->incoming_buf); } + /* + * Process release should go after buffers release to guarantee mmap + * existence. + */ + if (req_impl->process != NULL) { + nxt_unit_process_use(req->ctx, req_impl->process, -1); + + req_impl->process = NULL; + } + + pthread_mutex_lock(&ctx_impl->mutex); + nxt_queue_remove(&req_impl->link); nxt_queue_insert_tail(&ctx_impl->free_req, &req_impl->link); + pthread_mutex_unlock(&ctx_impl->mutex); + req_impl->state = NXT_UNIT_RS_RELEASED; } @@ -1120,11 +1155,13 @@ nxt_unit_websocket_frame_get(nxt_unit_ctx_t *ctx) ctx_impl = nxt_container_of(ctx, nxt_unit_ctx_impl_t, ctx); + pthread_mutex_lock(&ctx_impl->mutex); + if (nxt_queue_is_empty(&ctx_impl->free_ws)) { + pthread_mutex_unlock(&ctx_impl->mutex); + ws_impl = malloc(sizeof(nxt_unit_websocket_frame_impl_t)); if (nxt_slow_path(ws_impl == NULL)) { - nxt_unit_warn(ctx, "websocket frame allocation failed"); - return NULL; } @@ -1132,6 +1169,8 @@ nxt_unit_websocket_frame_get(nxt_unit_ctx_t *ctx) lnk = nxt_queue_first(&ctx_impl->free_ws); nxt_queue_remove(lnk); + pthread_mutex_unlock(&ctx_impl->mutex); + ws_impl = nxt_container_of(lnk, nxt_unit_websocket_frame_impl_t, link); } @@ -1160,7 +1199,11 @@ nxt_unit_websocket_frame_release(nxt_unit_websocket_frame_t *ws) ws_impl->retain_buf = NULL; } + pthread_mutex_lock(&ws_impl->ctx_impl->mutex); + nxt_queue_insert_tail(&ws_impl->ctx_impl->free_ws, &ws_impl->link); + + pthread_mutex_unlock(&ws_impl->ctx_impl->mutex); } @@ -1635,6 +1678,8 @@ nxt_unit_response_buf_alloc(nxt_unit_request_info_t *req, uint32_t size) mmap_buf = nxt_unit_mmap_buf_get(req->ctx); if (nxt_slow_path(mmap_buf == NULL)) { + nxt_unit_req_alert(req, "response_buf_alloc: failed to allocate buf"); + return NULL; } @@ -1688,16 +1733,22 @@ nxt_unit_mmap_buf_get(nxt_unit_ctx_t *ctx) ctx_impl = nxt_container_of(ctx, nxt_unit_ctx_impl_t, ctx); + pthread_mutex_lock(&ctx_impl->mutex); + if (ctx_impl->free_buf == NULL) { + pthread_mutex_unlock(&ctx_impl->mutex); + mmap_buf = malloc(sizeof(nxt_unit_mmap_buf_t)); if (nxt_slow_path(mmap_buf == NULL)) { - nxt_unit_warn(ctx, "failed to allocate buf"); + return NULL; } } else { mmap_buf = ctx_impl->free_buf; nxt_unit_mmap_buf_remove(mmap_buf); + + pthread_mutex_unlock(&ctx_impl->mutex); } mmap_buf->ctx_impl = ctx_impl; @@ -1711,7 +1762,11 @@ nxt_unit_mmap_buf_release(nxt_unit_mmap_buf_t *mmap_buf) { nxt_unit_mmap_buf_remove(mmap_buf); + pthread_mutex_lock(&mmap_buf->ctx_impl->mutex); + nxt_unit_mmap_buf_insert(&mmap_buf->ctx_impl->free_buf, mmap_buf); + + pthread_mutex_unlock(&mmap_buf->ctx_impl->mutex); } @@ -3298,7 +3353,14 @@ nxt_unit_ctx_alloc(nxt_unit_ctx_t *ctx, void *data) close(fd); - nxt_unit_ctx_init(lib, new_ctx, data); + rc = nxt_unit_ctx_init(lib, new_ctx, data); + if (nxt_slow_path(rc != NXT_UNIT_OK)) { + lib->callbacks.remove_port(ctx, &new_port_id); + + free(new_ctx); + + return NULL; + } new_ctx->read_port_id = new_port_id; @@ -3350,6 +3412,8 @@ nxt_unit_ctx_free(nxt_unit_ctx_t *ctx) } nxt_queue_loop; + pthread_mutex_destroy(&ctx_impl->mutex); + nxt_queue_remove(&ctx_impl->link); if (ctx_impl != &lib->main_ctx) { diff --git a/src/test/nxt_http_parse_test.c b/src/test/nxt_http_parse_test.c index 572e91b2..5498cb1f 100644 --- a/src/test/nxt_http_parse_test.c +++ b/src/test/nxt_http_parse_test.c @@ -21,8 +21,6 @@ typedef struct { unsigned quoted_target:1; /* target with " " */ unsigned space_in_target:1; - /* target with "+" */ - unsigned plus_in_target:1; } nxt_http_parse_test_request_line_t; @@ -70,7 +68,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 0, 0, 0 + 0, 0, 0 }} }, { @@ -80,10 +78,10 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { { .request_line = { nxt_string("XXX-METHOD"), nxt_string("/d.ir/fi+le.ext?key=val"), - nxt_string("ext?key=val"), + nxt_string("ext"), nxt_string("key=val"), "HTTP/1.2", - 0, 0, 0, 1 + 0, 0, 0 }} }, { @@ -96,7 +94,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_string(""), "HTTP/1.0", - 0, 0, 0, 0 + 0, 0, 0 }} }, { @@ -139,7 +137,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -152,7 +150,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -163,9 +161,9 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_string("GET"), nxt_string("/?#"), nxt_null_string, - nxt_string("#"), + nxt_string(""), "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -178,7 +176,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 1, 0, 0, 0 + 1, 0, 0 }} }, { @@ -191,7 +189,7 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 1, 0, 0 + 0, 1, 0 }} }, { @@ -204,7 +202,20 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { nxt_null_string, nxt_null_string, "HTTP/1.0", - 0, 0, 1, 0 + 0, 0, 1 + }} + }, + { + nxt_string("GET /na %20me.ext?args HTTP/1.0\r\n\r\n"), + NXT_DONE, + &nxt_http_parse_test_request_line, + { .request_line = { + nxt_string("GET"), + nxt_string("/na %20me.ext?args"), + nxt_string("ext"), + nxt_string("args"), + "HTTP/1.0", + 0, 1, 1 }} }, { @@ -214,10 +225,10 @@ static nxt_http_parse_test_case_t nxt_http_test_cases[] = { { .request_line = { nxt_string("GET"), nxt_string("/ HTTP/1.0"), - nxt_null_string, + nxt_string("0"), nxt_null_string, "HTTP/1.1", - 0, 0, 1, 0 + 0, 0, 1 }} }, { @@ -729,31 +740,23 @@ nxt_http_parse_test_request_line(nxt_http_request_parse_t *rp, return NXT_ERROR; } - str.length = (rp->exten_start != NULL) ? rp->target_end - rp->exten_start - : 0; - str.start = rp->exten_start; - - if (str.start != test->exten.start - && !nxt_strstr_eq(&str, &test->exten)) + if (rp->exten.start != test->exten.start + && !nxt_strstr_eq(&rp->exten, &test->exten)) { nxt_log_alert(log, "http parse test case failed:\n" " - request:\n\"%V\"\n" " - exten: \"%V\" (expected: \"%V\")", - request, &str, &test->exten); + request, &rp->exten, &test->exten); return NXT_ERROR; } - str.length = (rp->args_start != NULL) ? rp->target_end - rp->args_start - : 0; - str.start = rp->args_start; - - if (str.start != test->args.start - && !nxt_strstr_eq(&str, &test->args)) + if (rp->args.start != test->args.start + && !nxt_strstr_eq(&rp->args, &test->args)) { nxt_log_alert(log, "http parse test case failed:\n" " - request:\n\"%V\"\n" " - args: \"%V\" (expected: \"%V\")", - request, &str, &test->args); + request, &rp->args, &test->args); return NXT_ERROR; } @@ -790,14 +793,6 @@ nxt_http_parse_test_request_line(nxt_http_request_parse_t *rp, return NXT_ERROR; } - if (rp->plus_in_target != test->plus_in_target) { - nxt_log_alert(log, "http parse test case failed:\n" - " - request:\n\"%V\"\n" - " - plus_in_target: %d (expected: %d)", - request, rp->plus_in_target, test->plus_in_target); - return NXT_ERROR; - } - return NXT_OK; } |