diff options
Diffstat (limited to 'src/java/nginx/unit/websocket/server/UpgradeUtil.java')
-rw-r--r-- | src/java/nginx/unit/websocket/server/UpgradeUtil.java | 285 |
1 files changed, 285 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/server/UpgradeUtil.java b/src/java/nginx/unit/websocket/server/UpgradeUtil.java new file mode 100644 index 00000000..162f01c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/server/UpgradeUtil.java @@ -0,0 +1,285 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.server; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import javax.servlet.ServletException; +import javax.servlet.ServletRequest; +import javax.servlet.ServletResponse; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import javax.websocket.Endpoint; +import javax.websocket.Extension; +import javax.websocket.HandshakeResponse; +import javax.websocket.server.ServerEndpointConfig; + +import nginx.unit.Request; + +import org.apache.tomcat.util.codec.binary.Base64; +import org.apache.tomcat.util.res.StringManager; +import org.apache.tomcat.util.security.ConcurrentMessageDigest; +import nginx.unit.websocket.Constants; +import nginx.unit.websocket.Transformation; +import nginx.unit.websocket.TransformationFactory; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.WsHandshakeResponse; +import nginx.unit.websocket.pojo.PojoEndpointServer; + +public class UpgradeUtil { + + private static final StringManager sm = + StringManager.getManager(UpgradeUtil.class.getPackage().getName()); + private static final byte[] WS_ACCEPT = + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( + StandardCharsets.ISO_8859_1); + + private UpgradeUtil() { + // Utility class. Hide default constructor. + } + + /** + * Checks to see if this is an HTTP request that includes a valid upgrade + * request to web socket. + * <p> + * Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java + * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC + * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. + * @param request The request to check if it is an HTTP upgrade request for + * a WebSocket connection + * @param response The response associated with the request + * @return <code>true</code> if the request includes a HTTP Upgrade request + * for the WebSocket protocol, otherwise <code>false</code> + */ + public static boolean isWebSocketUpgradeRequest(ServletRequest request, + ServletResponse response) { + + Request r = (Request) request.getAttribute(Request.BARE); + + return ((request instanceof HttpServletRequest) && + (response instanceof HttpServletResponse) && + (r != null) && + (r.isUpgrade())); + } + + + public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, + HttpServletResponse resp, ServerEndpointConfig sec, + Map<String,String> pathParams) + throws ServletException, IOException { + + + // Origin check + String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); + + if (!sec.getConfigurator().checkOrigin(origin)) { + resp.sendError(HttpServletResponse.SC_FORBIDDEN); + return; + } + // Sub-protocols + List<String> subProtocols = getTokensFromHeader(req, + Constants.WS_PROTOCOL_HEADER_NAME); + String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( + sec.getSubprotocols(), subProtocols); + + // Extensions + // Should normally only be one header but handle the case of multiple + // headers + List<Extension> extensionsRequested = new ArrayList<>(); + Enumeration<String> extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); + while (extHeaders.hasMoreElements()) { + Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); + } + + // Negotiation phase 1. By default this simply filters out the + // extensions that the server does not support but applications could + // use a custom configurator to do more than this. + List<Extension> installedExtensions = null; + if (sec.getExtensions().size() == 0) { + installedExtensions = Constants.INSTALLED_EXTENSIONS; + } else { + installedExtensions = new ArrayList<>(); + installedExtensions.addAll(sec.getExtensions()); + installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); + } + List<Extension> negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( + installedExtensions, extensionsRequested); + + // Negotiation phase 2. Create the Transformations that will be applied + // to this connection. Note than an extension may be dropped at this + // point if the client has requested a configuration that the server is + // unable to support. + List<Transformation> transformations = createTransformations(negotiatedExtensionsPhase1); + + List<Extension> negotiatedExtensionsPhase2; + if (transformations.isEmpty()) { + negotiatedExtensionsPhase2 = Collections.emptyList(); + } else { + negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); + for (Transformation t : transformations) { + negotiatedExtensionsPhase2.add(t.getExtensionResponse()); + } + } + + WsHttpUpgradeHandler wsHandler = + req.upgrade(WsHttpUpgradeHandler.class); + + WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); + WsHandshakeResponse wsResponse = new WsHandshakeResponse(); + WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = + new WsPerSessionServerEndpointConfig(sec); + sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, + wsRequest, wsResponse); + //wsRequest.finished(); + + // Add any additional headers + for (Entry<String,List<String>> entry : + wsResponse.getHeaders().entrySet()) { + for (String headerValue: entry.getValue()) { + resp.addHeader(entry.getKey(), headerValue); + } + } + + Endpoint ep; + try { + Class<?> clazz = sec.getEndpointClass(); + if (Endpoint.class.isAssignableFrom(clazz)) { + ep = (Endpoint) sec.getConfigurator().getEndpointInstance( + clazz); + } else { + ep = new PojoEndpointServer(); + // Need to make path params available to POJO + perSessionServerEndpointConfig.getUserProperties().put( + nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); + } + } catch (InstantiationException e) { + throw new ServletException(e); + } + + wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, + negotiatedExtensionsPhase2, subProtocol, null, pathParams, + req.isSecure()); + + wsHandler.init(null); + } + + + private static List<Transformation> createTransformations( + List<Extension> negotiatedExtensions) { + + TransformationFactory factory = TransformationFactory.getInstance(); + + LinkedHashMap<String,List<List<Extension.Parameter>>> extensionPreferences = + new LinkedHashMap<>(); + + // Result will likely be smaller than this + List<Transformation> result = new ArrayList<>(negotiatedExtensions.size()); + + for (Extension extension : negotiatedExtensions) { + List<List<Extension.Parameter>> preferences = + extensionPreferences.get(extension.getName()); + + if (preferences == null) { + preferences = new ArrayList<>(); + extensionPreferences.put(extension.getName(), preferences); + } + + preferences.add(extension.getParameters()); + } + + for (Map.Entry<String,List<List<Extension.Parameter>>> entry : + extensionPreferences.entrySet()) { + Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); + if (transformation != null) { + result.add(transformation); + } + } + return result; + } + + + private static void append(StringBuilder sb, Extension extension) { + if (extension == null || extension.getName() == null || extension.getName().length() == 0) { + return; + } + + sb.append(extension.getName()); + + for (Extension.Parameter p : extension.getParameters()) { + sb.append(';'); + sb.append(p.getName()); + if (p.getValue() != null) { + sb.append('='); + sb.append(p.getValue()); + } + } + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static boolean headerContainsToken(HttpServletRequest req, + String headerName, String target) { + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + if (target.equalsIgnoreCase(token.trim())) { + return true; + } + } + } + return false; + } + + + /* + * This only works for tokens. Quoted strings need more sophisticated + * parsing. + */ + private static List<String> getTokensFromHeader(HttpServletRequest req, + String headerName) { + List<String> result = new ArrayList<>(); + Enumeration<String> headers = req.getHeaders(headerName); + while (headers.hasMoreElements()) { + String header = headers.nextElement(); + String[] tokens = header.split(","); + for (String token : tokens) { + result.add(token.trim()); + } + } + return result; + } + + + private static String getWebSocketAccept(String key) { + byte[] digest = ConcurrentMessageDigest.digestSHA1( + key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); + return Base64.encodeBase64String(digest); + } +} |