summaryrefslogtreecommitdiffhomepage
path: root/src/java/nginx/unit/websocket/server/UpgradeUtil.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/java/nginx/unit/websocket/server/UpgradeUtil.java')
-rw-r--r--src/java/nginx/unit/websocket/server/UpgradeUtil.java285
1 files changed, 285 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java
new file mode 100644
index 00000000..162f01c7
--- /dev/null
+++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package nginx.unit.websocket.server;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Enumeration;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+
+import javax.servlet.ServletException;
+import javax.servlet.ServletRequest;
+import javax.servlet.ServletResponse;
+import javax.servlet.http.HttpServletRequest;
+import javax.servlet.http.HttpServletResponse;
+import javax.websocket.Endpoint;
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.server.ServerEndpointConfig;
+
+import nginx.unit.Request;
+
+import org.apache.tomcat.util.codec.binary.Base64;
+import org.apache.tomcat.util.res.StringManager;
+import org.apache.tomcat.util.security.ConcurrentMessageDigest;
+import nginx.unit.websocket.Constants;
+import nginx.unit.websocket.Transformation;
+import nginx.unit.websocket.TransformationFactory;
+import nginx.unit.websocket.Util;
+import nginx.unit.websocket.WsHandshakeResponse;
+import nginx.unit.websocket.pojo.PojoEndpointServer;
+
+public class UpgradeUtil {
+
+ private static final StringManager sm =
+ StringManager.getManager(UpgradeUtil.class.getPackage().getName());
+ private static final byte[] WS_ACCEPT =
+ "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes(
+ StandardCharsets.ISO_8859_1);
+
+ private UpgradeUtil() {
+ // Utility class. Hide default constructor.
+ }
+
+ /**
+ * Checks to see if this is an HTTP request that includes a valid upgrade
+ * request to web socket.
+ * <p>
+ * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java
+ * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC
+ * 6455 section 4.1 requires that a WebSocket Upgrade uses GET.
+ * @param request The request to check if it is an HTTP upgrade request for
+ * a WebSocket connection
+ * @param response The response associated with the request
+ * @return <code>true</code> if the request includes a HTTP Upgrade request
+ * for the WebSocket protocol, otherwise <code>false</code>
+ */
+ public static boolean isWebSocketUpgradeRequest(ServletRequest request,
+ ServletResponse response) {
+
+ Request r = (Request) request.getAttribute(Request.BARE);
+
+ return ((request instanceof HttpServletRequest) &&
+ (response instanceof HttpServletResponse) &&
+ (r != null) &&
+ (r.isUpgrade()));
+ }
+
+
+ public static void doUpgrade(WsServerContainer sc, HttpServletRequest req,
+ HttpServletResponse resp, ServerEndpointConfig sec,
+ Map<String,String> pathParams)
+ throws ServletException, IOException {
+
+
+ // Origin check
+ String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME);
+
+ if (!sec.getConfigurator().checkOrigin(origin)) {
+ resp.sendError(HttpServletResponse.SC_FORBIDDEN);
+ return;
+ }
+ // Sub-protocols
+ List<String> subProtocols = getTokensFromHeader(req,
+ Constants.WS_PROTOCOL_HEADER_NAME);
+ String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol(
+ sec.getSubprotocols(), subProtocols);
+
+ // Extensions
+ // Should normally only be one header but handle the case of multiple
+ // headers
+ List<Extension> extensionsRequested = new ArrayList<>();
+ Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME);
+ while (extHeaders.hasMoreElements()) {
+ Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement());
+ }
+
+ // Negotiation phase 1. By default this simply filters out the
+ // extensions that the server does not support but applications could
+ // use a custom configurator to do more than this.
+ List<Extension> installedExtensions = null;
+ if (sec.getExtensions().size() == 0) {
+ installedExtensions = Constants.INSTALLED_EXTENSIONS;
+ } else {
+ installedExtensions = new ArrayList<>();
+ installedExtensions.addAll(sec.getExtensions());
+ installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS);
+ }
+ List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions(
+ installedExtensions, extensionsRequested);
+
+ // Negotiation phase 2. Create the Transformations that will be applied
+ // to this connection. Note than an extension may be dropped at this
+ // point if the client has requested a configuration that the server is
+ // unable to support.
+ List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1);
+
+ List<Extension> negotiatedExtensionsPhase2;
+ if (transformations.isEmpty()) {
+ negotiatedExtensionsPhase2 = Collections.emptyList();
+ } else {
+ negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size());
+ for (Transformation t : transformations) {
+ negotiatedExtensionsPhase2.add(t.getExtensionResponse());
+ }
+ }
+
+ WsHttpUpgradeHandler wsHandler =
+ req.upgrade(WsHttpUpgradeHandler.class);
+
+ WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams);
+ WsHandshakeResponse wsResponse = new WsHandshakeResponse();
+ WsPerSessionServerEndpointConfig perSessionServerEndpointConfig =
+ new WsPerSessionServerEndpointConfig(sec);
+ sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig,
+ wsRequest, wsResponse);
+ //wsRequest.finished();
+
+ // Add any additional headers
+ for (Entry<String,List<String>> entry :
+ wsResponse.getHeaders().entrySet()) {
+ for (String headerValue: entry.getValue()) {
+ resp.addHeader(entry.getKey(), headerValue);
+ }
+ }
+
+ Endpoint ep;
+ try {
+ Class<?> clazz = sec.getEndpointClass();
+ if (Endpoint.class.isAssignableFrom(clazz)) {
+ ep = (Endpoint) sec.getConfigurator().getEndpointInstance(
+ clazz);
+ } else {
+ ep = new PojoEndpointServer();
+ // Need to make path params available to POJO
+ perSessionServerEndpointConfig.getUserProperties().put(
+ nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams);
+ }
+ } catch (InstantiationException e) {
+ throw new ServletException(e);
+ }
+
+ wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest,
+ negotiatedExtensionsPhase2, subProtocol, null, pathParams,
+ req.isSecure());
+
+ wsHandler.init(null);
+ }
+
+
+ private static List<Transformation> createTransformations(
+ List<Extension> negotiatedExtensions) {
+
+ TransformationFactory factory = TransformationFactory.getInstance();
+
+ LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences =
+ new LinkedHashMap<>();
+
+ // Result will likely be smaller than this
+ List<Transformation> result = new ArrayList<>(negotiatedExtensions.size());
+
+ for (Extension extension : negotiatedExtensions) {
+ List<List<Extension.Parameter>> preferences =
+ extensionPreferences.get(extension.getName());
+
+ if (preferences == null) {
+ preferences = new ArrayList<>();
+ extensionPreferences.put(extension.getName(), preferences);
+ }
+
+ preferences.add(extension.getParameters());
+ }
+
+ for (Map.Entry<String,List<List<Extension.Parameter>>> entry :
+ extensionPreferences.entrySet()) {
+ Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true);
+ if (transformation != null) {
+ result.add(transformation);
+ }
+ }
+ return result;
+ }
+
+
+ private static void append(StringBuilder sb, Extension extension) {
+ if (extension == null || extension.getName() == null || extension.getName().length() == 0) {
+ return;
+ }
+
+ sb.append(extension.getName());
+
+ for (Extension.Parameter p : extension.getParameters()) {
+ sb.append(';');
+ sb.append(p.getName());
+ if (p.getValue() != null) {
+ sb.append('=');
+ sb.append(p.getValue());
+ }
+ }
+ }
+
+
+ /*
+ * This only works for tokens. Quoted strings need more sophisticated
+ * parsing.
+ */
+ private static boolean headerContainsToken(HttpServletRequest req,
+ String headerName, String target) {
+ Enumeration<String> headers = req.getHeaders(headerName);
+ while (headers.hasMoreElements()) {
+ String header = headers.nextElement();
+ String[] tokens = header.split(",");
+ for (String token : tokens) {
+ if (target.equalsIgnoreCase(token.trim())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+
+ /*
+ * This only works for tokens. Quoted strings need more sophisticated
+ * parsing.
+ */
+ private static List<String> getTokensFromHeader(HttpServletRequest req,
+ String headerName) {
+ List<String> result = new ArrayList<>();
+ Enumeration<String> headers = req.getHeaders(headerName);
+ while (headers.hasMoreElements()) {
+ String header = headers.nextElement();
+ String[] tokens = header.split(",");
+ for (String token : tokens) {
+ result.add(token.trim());
+ }
+ }
+ return result;
+ }
+
+
+ private static String getWebSocketAccept(String key) {
+ byte[] digest = ConcurrentMessageDigest.digestSHA1(
+ key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT);
+ return Base64.encodeBase64String(digest);
+ }
+}