diff options
Diffstat (limited to '')
-rw-r--r-- | src/java/nginx/unit/websocket/Util.java | 666 |
1 files changed, 666 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/Util.java b/src/java/nginx/unit/websocket/Util.java new file mode 100644 index 00000000..6acf3ade --- /dev/null +++ b/src/java/nginx/unit/websocket/Util.java @@ -0,0 +1,666 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.Method; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.nio.ByteBuffer; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedQueue; + +import javax.websocket.CloseReason.CloseCode; +import javax.websocket.CloseReason.CloseCodes; +import javax.websocket.Decoder; +import javax.websocket.Decoder.Binary; +import javax.websocket.Decoder.BinaryStream; +import javax.websocket.Decoder.Text; +import javax.websocket.Decoder.TextStream; +import javax.websocket.DeploymentException; +import javax.websocket.Encoder; +import javax.websocket.EndpointConfig; +import javax.websocket.Extension; +import javax.websocket.MessageHandler; +import javax.websocket.PongMessage; +import javax.websocket.Session; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.pojo.PojoMessageHandlerPartialBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeBinary; +import nginx.unit.websocket.pojo.PojoMessageHandlerWholeText; + +/** + * Utility class for internal use only within the + * {@link nginx.unit.websocket} package. + */ +public class Util { + + private static final StringManager sm = StringManager.getManager(Util.class); + private static final Queue<SecureRandom> randoms = + new ConcurrentLinkedQueue<>(); + + private Util() { + // Hide default constructor + } + + + static boolean isControl(byte opCode) { + return (opCode & 0x08) != 0; + } + + + static boolean isText(byte opCode) { + return opCode == Constants.OPCODE_TEXT; + } + + + static boolean isContinuation(byte opCode) { + return opCode == Constants.OPCODE_CONTINUATION; + } + + + static CloseCode getCloseCode(int code) { + if (code > 2999 && code < 5000) { + return CloseCodes.getCloseCode(code); + } + switch (code) { + case 1000: + return CloseCodes.NORMAL_CLOSURE; + case 1001: + return CloseCodes.GOING_AWAY; + case 1002: + return CloseCodes.PROTOCOL_ERROR; + case 1003: + return CloseCodes.CANNOT_ACCEPT; + case 1004: + // Should not be used in a close frame + // return CloseCodes.RESERVED; + return CloseCodes.PROTOCOL_ERROR; + case 1005: + // Should not be used in a close frame + // return CloseCodes.NO_STATUS_CODE; + return CloseCodes.PROTOCOL_ERROR; + case 1006: + // Should not be used in a close frame + // return CloseCodes.CLOSED_ABNORMALLY; + return CloseCodes.PROTOCOL_ERROR; + case 1007: + return CloseCodes.NOT_CONSISTENT; + case 1008: + return CloseCodes.VIOLATED_POLICY; + case 1009: + return CloseCodes.TOO_BIG; + case 1010: + return CloseCodes.NO_EXTENSION; + case 1011: + return CloseCodes.UNEXPECTED_CONDITION; + case 1012: + // Not in RFC6455 + // return CloseCodes.SERVICE_RESTART; + return CloseCodes.PROTOCOL_ERROR; + case 1013: + // Not in RFC6455 + // return CloseCodes.TRY_AGAIN_LATER; + return CloseCodes.PROTOCOL_ERROR; + case 1015: + // Should not be used in a close frame + // return CloseCodes.TLS_HANDSHAKE_FAILURE; + return CloseCodes.PROTOCOL_ERROR; + default: + return CloseCodes.PROTOCOL_ERROR; + } + } + + + static byte[] generateMask() { + // SecureRandom is not thread-safe so need to make sure only one thread + // uses it at a time. In theory, the pool could grow to the same size + // as the number of request processing threads. In reality it will be + // a lot smaller. + + // Get a SecureRandom from the pool + SecureRandom sr = randoms.poll(); + + // If one isn't available, generate a new one + if (sr == null) { + try { + sr = SecureRandom.getInstance("SHA1PRNG"); + } catch (NoSuchAlgorithmException e) { + // Fall back to platform default + sr = new SecureRandom(); + } + } + + // Generate the mask + byte[] result = new byte[4]; + sr.nextBytes(result); + + // Put the SecureRandom back in the poll + randoms.add(sr); + + return result; + } + + + static Class<?> getMessageType(MessageHandler listener) { + return Util.getGenericType(MessageHandler.class, + listener.getClass()).getClazz(); + } + + + private static Class<?> getDecoderType(Class<? extends Decoder> decoder) { + return Util.getGenericType(Decoder.class, decoder).getClazz(); + } + + + static Class<?> getEncoderType(Class<? extends Encoder> encoder) { + return Util.getGenericType(Encoder.class, encoder).getClazz(); + } + + + private static <T> TypeResult getGenericType(Class<T> type, + Class<? extends T> clazz) { + + // Look to see if this class implements the interface of interest + + // Get all the interfaces + Type[] interfaces = clazz.getGenericInterfaces(); + for (Type iface : interfaces) { + // Only need to check interfaces that use generics + if (iface instanceof ParameterizedType) { + ParameterizedType pi = (ParameterizedType) iface; + // Look for the interface of interest + if (pi.getRawType() instanceof Class) { + if (type.isAssignableFrom((Class<?>) pi.getRawType())) { + return getTypeParameter( + clazz, pi.getActualTypeArguments()[0]); + } + } + } + } + + // Interface not found on this class. Look at the superclass. + @SuppressWarnings("unchecked") + Class<? extends T> superClazz = + (Class<? extends T>) clazz.getSuperclass(); + if (superClazz == null) { + // Finished looking up the class hierarchy without finding anything + return null; + } + + TypeResult superClassTypeResult = getGenericType(type, superClazz); + int dimension = superClassTypeResult.getDimension(); + if (superClassTypeResult.getIndex() == -1 && dimension == 0) { + // Superclass implements interface and defines explicit type for + // the interface of interest + return superClassTypeResult; + } + + if (superClassTypeResult.getIndex() > -1) { + // Superclass implements interface and defines unknown type for + // the interface of interest + // Map that unknown type to the generic types defined in this class + ParameterizedType superClassType = + (ParameterizedType) clazz.getGenericSuperclass(); + TypeResult result = getTypeParameter(clazz, + superClassType.getActualTypeArguments()[ + superClassTypeResult.getIndex()]); + result.incrementDimension(superClassTypeResult.getDimension()); + if (result.getClazz() != null && result.getDimension() > 0) { + superClassTypeResult = result; + } else { + return result; + } + } + + if (superClassTypeResult.getDimension() > 0) { + StringBuilder className = new StringBuilder(); + for (int i = 0; i < dimension; i++) { + className.append('['); + } + className.append('L'); + className.append(superClassTypeResult.getClazz().getCanonicalName()); + className.append(';'); + + Class<?> arrayClazz; + try { + arrayClazz = Class.forName(className.toString()); + } catch (ClassNotFoundException e) { + throw new IllegalArgumentException(e); + } + + return new TypeResult(arrayClazz, -1, 0); + } + + // Error will be logged further up the call stack + return null; + } + + + /* + * For a generic parameter, return either the Class used or if the type + * is unknown, the index for the type in definition of the class + */ + private static TypeResult getTypeParameter(Class<?> clazz, Type argType) { + if (argType instanceof Class<?>) { + return new TypeResult((Class<?>) argType, -1, 0); + } else if (argType instanceof ParameterizedType) { + return new TypeResult((Class<?>)((ParameterizedType) argType).getRawType(), -1, 0); + } else if (argType instanceof GenericArrayType) { + Type arrayElementType = ((GenericArrayType) argType).getGenericComponentType(); + TypeResult result = getTypeParameter(clazz, arrayElementType); + result.incrementDimension(1); + return result; + } else { + TypeVariable<?>[] tvs = clazz.getTypeParameters(); + for (int i = 0; i < tvs.length; i++) { + if (tvs[i].equals(argType)) { + return new TypeResult(null, i, 0); + } + } + return null; + } + } + + + public static boolean isPrimitive(Class<?> clazz) { + if (clazz.isPrimitive()) { + return true; + } else if(clazz.equals(Boolean.class) || + clazz.equals(Byte.class) || + clazz.equals(Character.class) || + clazz.equals(Double.class) || + clazz.equals(Float.class) || + clazz.equals(Integer.class) || + clazz.equals(Long.class) || + clazz.equals(Short.class)) { + return true; + } + return false; + } + + + public static Object coerceToType(Class<?> type, String value) { + if (type.equals(String.class)) { + return value; + } else if (type.equals(boolean.class) || type.equals(Boolean.class)) { + return Boolean.valueOf(value); + } else if (type.equals(byte.class) || type.equals(Byte.class)) { + return Byte.valueOf(value); + } else if (type.equals(char.class) || type.equals(Character.class)) { + return Character.valueOf(value.charAt(0)); + } else if (type.equals(double.class) || type.equals(Double.class)) { + return Double.valueOf(value); + } else if (type.equals(float.class) || type.equals(Float.class)) { + return Float.valueOf(value); + } else if (type.equals(int.class) || type.equals(Integer.class)) { + return Integer.valueOf(value); + } else if (type.equals(long.class) || type.equals(Long.class)) { + return Long.valueOf(value); + } else if (type.equals(short.class) || type.equals(Short.class)) { + return Short.valueOf(value); + } else { + throw new IllegalArgumentException(sm.getString( + "util.invalidType", value, type.getName())); + } + } + + + public static List<DecoderEntry> getDecoders( + List<Class<? extends Decoder>> decoderClazzes) + throws DeploymentException{ + + List<DecoderEntry> result = new ArrayList<>(); + if (decoderClazzes != null) { + for (Class<? extends Decoder> decoderClazz : decoderClazzes) { + // Need to instantiate decoder to ensure it is valid and that + // deployment can be failed if it is not + @SuppressWarnings("unused") + Decoder instance; + try { + instance = decoderClazz.getConstructor().newInstance(); + } catch (ReflectiveOperationException e) { + throw new DeploymentException( + sm.getString("pojoMethodMapping.invalidDecoder", + decoderClazz.getName()), e); + } + DecoderEntry entry = new DecoderEntry( + Util.getDecoderType(decoderClazz), decoderClazz); + result.add(entry); + } + } + + return result; + } + + + static Set<MessageHandlerResult> getMessageHandlers(Class<?> target, + MessageHandler listener, EndpointConfig endpointConfig, + Session session) { + + // Will never be more than 2 types + Set<MessageHandlerResult> results = new HashSet<>(2); + + // Simple cases - handlers already accepts one of the types expected by + // the frame handling code + if (String.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.TEXT); + results.add(result); + } else if (ByteBuffer.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.BINARY); + results.add(result); + } else if (PongMessage.class.isAssignableFrom(target)) { + MessageHandlerResult result = + new MessageHandlerResult(listener, + MessageHandlerResultType.PONG); + results.add(result); + // Handler needs wrapping and optional decoder to convert it to one of + // the types expected by the frame handling code + } else if (byte[].class.isAssignableFrom(target)) { + boolean whole = MessageHandler.Whole.class.isAssignableFrom(listener.getClass()); + MessageHandlerResult result = new MessageHandlerResult( + whole ? new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, false, -1) : + new PojoMessageHandlerPartialBinary(listener, + getOnMessagePartialMethod(listener), session, + new Object[2], 0, true, 1, -1, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (InputStream.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, true), + new Object[1], 0, true, -1, true, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } else if (Reader.class.isAssignableFrom(target)) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, + getOnMessageMethod(listener), session, + endpointConfig, matchDecoders(target, endpointConfig, false), + new Object[1], 0, true, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } else { + // Handler needs wrapping and requires decoder to convert it to one + // of the types expected by the frame handling code + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + Method m = getOnMessageMethod(listener); + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeBinary(listener, m, session, + endpointConfig, + decoderMatch.getBinaryDecoders(), new Object[1], + 0, false, -1, false, -1), + MessageHandlerResultType.BINARY); + results.add(result); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandlerResult result = new MessageHandlerResult( + new PojoMessageHandlerWholeText(listener, m, session, + endpointConfig, + decoderMatch.getTextDecoders(), new Object[1], + 0, false, -1, -1), + MessageHandlerResultType.TEXT); + results.add(result); + } + } + + if (results.size() == 0) { + throw new IllegalArgumentException( + sm.getString("wsSession.unknownHandler", listener, target)); + } + + return results; + } + + private static List<Class<? extends Decoder>> matchDecoders(Class<?> target, + EndpointConfig endpointConfig, boolean binary) { + DecoderMatch decoderMatch = matchDecoders(target, endpointConfig); + if (binary) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + return decoderMatch.getBinaryDecoders(); + } + } else if (decoderMatch.getTextDecoders().size() > 0) { + return decoderMatch.getTextDecoders(); + } + return null; + } + + private static DecoderMatch matchDecoders(Class<?> target, + EndpointConfig endpointConfig) { + DecoderMatch decoderMatch; + try { + List<Class<? extends Decoder>> decoders = + endpointConfig.getDecoders(); + List<DecoderEntry> decoderEntries = getDecoders(decoders); + decoderMatch = new DecoderMatch(target, decoderEntries); + } catch (DeploymentException e) { + throw new IllegalArgumentException(e); + } + return decoderMatch; + } + + public static void parseExtensionHeader(List<Extension> extensions, + String header) { + // The relevant ABNF for the Sec-WebSocket-Extensions is as follows: + // extension-list = 1#extension + // extension = extension-token *( ";" extension-param ) + // extension-token = registered-token + // registered-token = token + // extension-param = token [ "=" (token | quoted-string) ] + // ; When using the quoted-string syntax variant, the value + // ; after quoted-string unescaping MUST conform to the + // ; 'token' ABNF. + // + // The limiting of parameter values to tokens or "quoted tokens" makes + // the parsing of the header significantly simpler and allows a number + // of short-cuts to be taken. + + // Step one, split the header into individual extensions using ',' as a + // separator + String unparsedExtensions[] = header.split(","); + for (String unparsedExtension : unparsedExtensions) { + // Step two, split the extension into the registered name and + // parameter/value pairs using ';' as a separator + String unparsedParameters[] = unparsedExtension.split(";"); + WsExtension extension = new WsExtension(unparsedParameters[0].trim()); + + for (int i = 1; i < unparsedParameters.length; i++) { + int equalsPos = unparsedParameters[i].indexOf('='); + String name; + String value; + if (equalsPos == -1) { + name = unparsedParameters[i].trim(); + value = null; + } else { + name = unparsedParameters[i].substring(0, equalsPos).trim(); + value = unparsedParameters[i].substring(equalsPos + 1).trim(); + int len = value.length(); + if (len > 1) { + if (value.charAt(0) == '\"' && value.charAt(len - 1) == '\"') { + value = value.substring(1, value.length() - 1); + } + } + } + // Make sure value doesn't contain any of the delimiters since + // that would indicate something went wrong + if (containsDelims(name) || containsDelims(value)) { + throw new IllegalArgumentException(sm.getString( + "util.notToken", name, value)); + } + if (value != null && + (value.indexOf(',') > -1 || value.indexOf(';') > -1 || + value.indexOf('\"') > -1 || value.indexOf('=') > -1)) { + throw new IllegalArgumentException(sm.getString("", value)); + } + extension.addParameter(new WsExtensionParameter(name, value)); + } + extensions.add(extension); + } + } + + + private static boolean containsDelims(String input) { + if (input == null || input.length() == 0) { + return false; + } + for (char c : input.toCharArray()) { + switch (c) { + case ',': + case ';': + case '\"': + case '=': + return true; + default: + // NO_OP + } + + } + return false; + } + + private static Method getOnMessageMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + private static Method getOnMessagePartialMethod(MessageHandler listener) { + try { + return listener.getClass().getMethod("onMessage", Object.class, Boolean.TYPE); + } catch (NoSuchMethodException | SecurityException e) { + throw new IllegalArgumentException( + sm.getString("util.invalidMessageHandler"), e); + } + } + + + public static class DecoderMatch { + + private final List<Class<? extends Decoder>> textDecoders = + new ArrayList<>(); + private final List<Class<? extends Decoder>> binaryDecoders = + new ArrayList<>(); + private final Class<?> target; + + public DecoderMatch(Class<?> target, List<DecoderEntry> decoderEntries) { + this.target = target; + for (DecoderEntry decoderEntry : decoderEntries) { + if (decoderEntry.getClazz().isAssignableFrom(target)) { + if (Binary.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (BinaryStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + binaryDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else if (Text.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // willDecode() method means this decoder may or may not + // decode a message so need to carry on checking for + // other matches + } else if (TextStream.class.isAssignableFrom( + decoderEntry.getDecoderClazz())) { + textDecoders.add(decoderEntry.getDecoderClazz()); + // Stream decoders have to process the message so no + // more decoders can be matched + break; + } else { + throw new IllegalArgumentException( + sm.getString("util.unknownDecoderType")); + } + } + } + } + + + public List<Class<? extends Decoder>> getTextDecoders() { + return textDecoders; + } + + + public List<Class<? extends Decoder>> getBinaryDecoders() { + return binaryDecoders; + } + + + public Class<?> getTarget() { + return target; + } + + + public boolean hasMatches() { + return (textDecoders.size() > 0) || (binaryDecoders.size() > 0); + } + } + + + private static class TypeResult { + private final Class<?> clazz; + private final int index; + private int dimension; + + public TypeResult(Class<?> clazz, int index, int dimension) { + this.clazz= clazz; + this.index = index; + this.dimension = dimension; + } + + public Class<?> getClazz() { + return clazz; + } + + public int getIndex() { + return index; + } + + public int getDimension() { + return dimension; + } + + public void incrementDimension(int inc) { + dimension += inc; + } + } +} |