/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package nginx.unit.websocket.server; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.Collections; import java.util.Enumeration; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Map.Entry; import javax.servlet.ServletException; import javax.servlet.ServletRequest; import javax.servlet.ServletResponse; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import javax.websocket.Endpoint; import javax.websocket.Extension; import javax.websocket.HandshakeResponse; import javax.websocket.server.ServerEndpointConfig; import nginx.unit.Request; import org.apache.tomcat.util.codec.binary.Base64; import org.apache.tomcat.util.res.StringManager; import org.apache.tomcat.util.security.ConcurrentMessageDigest; import nginx.unit.websocket.Constants; import nginx.unit.websocket.Transformation; import nginx.unit.websocket.TransformationFactory; import nginx.unit.websocket.Util; import nginx.unit.websocket.WsHandshakeResponse; import nginx.unit.websocket.pojo.PojoEndpointServer; public class UpgradeUtil { private static final StringManager sm = StringManager.getManager(UpgradeUtil.class.getPackage().getName()); private static final byte[] WS_ACCEPT = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".getBytes( StandardCharsets.ISO_8859_1); private UpgradeUtil() { // Utility class. Hide default constructor. } /** * Checks to see if this is an HTTP request that includes a valid upgrade * request to web socket. *

* Note: RFC 2616 does not limit HTTP upgrade to GET requests but the Java * WebSocket spec 1.0, section 8.2 implies such a limitation and RFC * 6455 section 4.1 requires that a WebSocket Upgrade uses GET. * @param request The request to check if it is an HTTP upgrade request for * a WebSocket connection * @param response The response associated with the request * @return true if the request includes a HTTP Upgrade request * for the WebSocket protocol, otherwise false */ public static boolean isWebSocketUpgradeRequest(ServletRequest request, ServletResponse response) { Request r = (Request) request.getAttribute(Request.BARE); return ((request instanceof HttpServletRequest) && (response instanceof HttpServletResponse) && (r != null) && (r.isUpgrade())); } public static void doUpgrade(WsServerContainer sc, HttpServletRequest req, HttpServletResponse resp, ServerEndpointConfig sec, Map pathParams) throws ServletException, IOException { // Origin check String origin = req.getHeader(Constants.ORIGIN_HEADER_NAME); if (!sec.getConfigurator().checkOrigin(origin)) { resp.sendError(HttpServletResponse.SC_FORBIDDEN); return; } // Sub-protocols List subProtocols = getTokensFromHeader(req, Constants.WS_PROTOCOL_HEADER_NAME); String subProtocol = sec.getConfigurator().getNegotiatedSubprotocol( sec.getSubprotocols(), subProtocols); // Extensions // Should normally only be one header but handle the case of multiple // headers List extensionsRequested = new ArrayList<>(); Enumeration extHeaders = req.getHeaders(Constants.WS_EXTENSIONS_HEADER_NAME); while (extHeaders.hasMoreElements()) { Util.parseExtensionHeader(extensionsRequested, extHeaders.nextElement()); } // Negotiation phase 1. By default this simply filters out the // extensions that the server does not support but applications could // use a custom configurator to do more than this. List installedExtensions = null; if (sec.getExtensions().size() == 0) { installedExtensions = Constants.INSTALLED_EXTENSIONS; } else { installedExtensions = new ArrayList<>(); installedExtensions.addAll(sec.getExtensions()); installedExtensions.addAll(Constants.INSTALLED_EXTENSIONS); } List negotiatedExtensionsPhase1 = sec.getConfigurator().getNegotiatedExtensions( installedExtensions, extensionsRequested); // Negotiation phase 2. Create the Transformations that will be applied // to this connection. Note than an extension may be dropped at this // point if the client has requested a configuration that the server is // unable to support. List transformations = createTransformations(negotiatedExtensionsPhase1); List negotiatedExtensionsPhase2; if (transformations.isEmpty()) { negotiatedExtensionsPhase2 = Collections.emptyList(); } else { negotiatedExtensionsPhase2 = new ArrayList<>(transformations.size()); for (Transformation t : transformations) { negotiatedExtensionsPhase2.add(t.getExtensionResponse()); } } WsHttpUpgradeHandler wsHandler = req.upgrade(WsHttpUpgradeHandler.class); WsHandshakeRequest wsRequest = new WsHandshakeRequest(req, pathParams); WsHandshakeResponse wsResponse = new WsHandshakeResponse(); WsPerSessionServerEndpointConfig perSessionServerEndpointConfig = new WsPerSessionServerEndpointConfig(sec); sec.getConfigurator().modifyHandshake(perSessionServerEndpointConfig, wsRequest, wsResponse); //wsRequest.finished(); // Add any additional headers for (Entry> entry : wsResponse.getHeaders().entrySet()) { for (String headerValue: entry.getValue()) { resp.addHeader(entry.getKey(), headerValue); } } Endpoint ep; try { Class clazz = sec.getEndpointClass(); if (Endpoint.class.isAssignableFrom(clazz)) { ep = (Endpoint) sec.getConfigurator().getEndpointInstance( clazz); } else { ep = new PojoEndpointServer(); // Need to make path params available to POJO perSessionServerEndpointConfig.getUserProperties().put( nginx.unit.websocket.pojo.Constants.POJO_PATH_PARAM_KEY, pathParams); } } catch (InstantiationException e) { throw new ServletException(e); } wsHandler.preInit(ep, perSessionServerEndpointConfig, sc, wsRequest, negotiatedExtensionsPhase2, subProtocol, null, pathParams, req.isSecure()); wsHandler.init(null); } private static List createTransformations( List negotiatedExtensions) { TransformationFactory factory = TransformationFactory.getInstance(); LinkedHashMap>> extensionPreferences = new LinkedHashMap<>(); // Result will likely be smaller than this List result = new ArrayList<>(negotiatedExtensions.size()); for (Extension extension : negotiatedExtensions) { List> preferences = extensionPreferences.get(extension.getName()); if (preferences == null) { preferences = new ArrayList<>(); extensionPreferences.put(extension.getName(), preferences); } preferences.add(extension.getParameters()); } for (Map.Entry>> entry : extensionPreferences.entrySet()) { Transformation transformation = factory.create(entry.getKey(), entry.getValue(), true); if (transformation != null) { result.add(transformation); } } return result; } private static void append(StringBuilder sb, Extension extension) { if (extension == null || extension.getName() == null || extension.getName().length() == 0) { return; } sb.append(extension.getName()); for (Extension.Parameter p : extension.getParameters()) { sb.append(';'); sb.append(p.getName()); if (p.getValue() != null) { sb.append('='); sb.append(p.getValue()); } } } /* * This only works for tokens. Quoted strings need more sophisticated * parsing. */ private static boolean headerContainsToken(HttpServletRequest req, String headerName, String target) { Enumeration headers = req.getHeaders(headerName); while (headers.hasMoreElements()) { String header = headers.nextElement(); String[] tokens = header.split(","); for (String token : tokens) { if (target.equalsIgnoreCase(token.trim())) { return true; } } } return false; } /* * This only works for tokens. Quoted strings need more sophisticated * parsing. */ private static List getTokensFromHeader(HttpServletRequest req, String headerName) { List result = new ArrayList<>(); Enumeration headers = req.getHeaders(headerName); while (headers.hasMoreElements()) { String header = headers.nextElement(); String[] tokens = header.split(","); for (String token : tokens) { result.add(token.trim()); } } return result; } private static String getWebSocketAccept(String key) { byte[] digest = ConcurrentMessageDigest.digestSHA1( key.getBytes(StandardCharsets.ISO_8859_1), WS_ACCEPT); return Base64.encodeBase64String(digest); } }