diff options
Diffstat (limited to 'src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java')
-rw-r--r-- | src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java | 731 |
1 files changed, 731 insertions, 0 deletions
diff --git a/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java new file mode 100644 index 00000000..2385b5c7 --- /dev/null +++ b/src/java/nginx/unit/websocket/pojo/PojoMethodMapping.java @@ -0,0 +1,731 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package nginx.unit.websocket.pojo; + +import java.io.InputStream; +import java.io.Reader; +import java.lang.annotation.Annotation; +import java.lang.reflect.Method; +import java.lang.reflect.Modifier; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import javax.websocket.CloseReason; +import javax.websocket.DecodeException; +import javax.websocket.Decoder; +import javax.websocket.DeploymentException; +import javax.websocket.EndpointConfig; +import javax.websocket.MessageHandler; +import javax.websocket.OnClose; +import javax.websocket.OnError; +import javax.websocket.OnMessage; +import javax.websocket.OnOpen; +import javax.websocket.PongMessage; +import javax.websocket.Session; +import javax.websocket.server.PathParam; + +import org.apache.tomcat.util.res.StringManager; +import nginx.unit.websocket.DecoderEntry; +import nginx.unit.websocket.Util; +import nginx.unit.websocket.Util.DecoderMatch; + +/** + * For a POJO class annotated with + * {@link javax.websocket.server.ServerEndpoint}, an instance of this class + * creates and caches the method handler, method information and parameter + * information for the onXXX calls. + */ +public class PojoMethodMapping { + + private static final StringManager sm = + StringManager.getManager(PojoMethodMapping.class); + + private final Method onOpen; + private final Method onClose; + private final Method onError; + private final PojoPathParam[] onOpenParams; + private final PojoPathParam[] onCloseParams; + private final PojoPathParam[] onErrorParams; + private final List<MessageHandlerInfo> onMessage = new ArrayList<>(); + private final String wsPath; + + + public PojoMethodMapping(Class<?> clazzPojo, + List<Class<? extends Decoder>> decoderClazzes, String wsPath) + throws DeploymentException { + + this.wsPath = wsPath; + + List<DecoderEntry> decoders = Util.getDecoders(decoderClazzes); + Method open = null; + Method close = null; + Method error = null; + Method[] clazzPojoMethods = null; + Class<?> currentClazz = clazzPojo; + while (!currentClazz.equals(Object.class)) { + Method[] currentClazzMethods = currentClazz.getDeclaredMethods(); + if (currentClazz == clazzPojo) { + clazzPojoMethods = currentClazzMethods; + } + for (Method method : currentClazzMethods) { + if (method.getAnnotation(OnOpen.class) != null) { + checkPublic(method); + if (open == null) { + open = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(open, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnOpen.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnClose.class) != null) { + checkPublic(method); + if (close == null) { + close = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(close, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnClose.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnError.class) != null) { + checkPublic(method); + if (error == null) { + error = method; + } else { + if (currentClazz == clazzPojo || + !isMethodOverride(error, method)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnError.class, currentClazz)); + } + } + } else if (method.getAnnotation(OnMessage.class) != null) { + checkPublic(method); + MessageHandlerInfo messageHandler = new MessageHandlerInfo(method, decoders); + boolean found = false; + for (MessageHandlerInfo otherMessageHandler : onMessage) { + if (messageHandler.targetsSameWebSocketMessageType(otherMessageHandler)) { + found = true; + if (currentClazz == clazzPojo || + !isMethodOverride(messageHandler.m, otherMessageHandler.m)) { + // Duplicate annotation + throw new DeploymentException(sm.getString( + "pojoMethodMapping.duplicateAnnotation", + OnMessage.class, currentClazz)); + } + } + } + if (!found) { + onMessage.add(messageHandler); + } + } else { + // Method not annotated + } + } + currentClazz = currentClazz.getSuperclass(); + } + // If the methods are not on clazzPojo and they are overridden + // by a non annotated method in clazzPojo, they should be ignored + if (open != null && open.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, open, OnOpen.class)) { + open = null; + } + } + if (close != null && close.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, close, OnClose.class)) { + close = null; + } + } + if (error != null && error.getDeclaringClass() != clazzPojo) { + if (isOverridenWithoutAnnotation(clazzPojoMethods, error, OnError.class)) { + error = null; + } + } + List<MessageHandlerInfo> overriddenOnMessage = new ArrayList<>(); + for (MessageHandlerInfo messageHandler : onMessage) { + if (messageHandler.m.getDeclaringClass() != clazzPojo + && isOverridenWithoutAnnotation(clazzPojoMethods, messageHandler.m, OnMessage.class)) { + overriddenOnMessage.add(messageHandler); + } + } + for (MessageHandlerInfo messageHandler : overriddenOnMessage) { + onMessage.remove(messageHandler); + } + this.onOpen = open; + this.onClose = close; + this.onError = error; + onOpenParams = getPathParams(onOpen, MethodType.ON_OPEN); + onCloseParams = getPathParams(onClose, MethodType.ON_CLOSE); + onErrorParams = getPathParams(onError, MethodType.ON_ERROR); + } + + + private void checkPublic(Method m) throws DeploymentException { + if (!Modifier.isPublic(m.getModifiers())) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.methodNotPublic", m.getName())); + } + } + + + private boolean isMethodOverride(Method method1, Method method2) { + return method1.getName().equals(method2.getName()) + && method1.getReturnType().equals(method2.getReturnType()) + && Arrays.equals(method1.getParameterTypes(), method2.getParameterTypes()); + } + + + private boolean isOverridenWithoutAnnotation(Method[] methods, + Method superclazzMethod, Class<? extends Annotation> annotation) { + for (Method method : methods) { + if (isMethodOverride(method, superclazzMethod) + && (method.getAnnotation(annotation) == null)) { + return true; + } + } + return false; + } + + + public String getWsPath() { + return wsPath; + } + + + public Method getOnOpen() { + return onOpen; + } + + + public Object[] getOnOpenArgs(Map<String,String> pathParameters, + Session session, EndpointConfig config) throws DecodeException { + return buildArgs(onOpenParams, pathParameters, session, config, null, + null); + } + + + public Method getOnClose() { + return onClose; + } + + + public Object[] getOnCloseArgs(Map<String,String> pathParameters, + Session session, CloseReason closeReason) throws DecodeException { + return buildArgs(onCloseParams, pathParameters, session, null, null, + closeReason); + } + + + public Method getOnError() { + return onError; + } + + + public Object[] getOnErrorArgs(Map<String,String> pathParameters, + Session session, Throwable throwable) throws DecodeException { + return buildArgs(onErrorParams, pathParameters, session, null, + throwable, null); + } + + + public boolean hasMessageHandlers() { + return !onMessage.isEmpty(); + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Set<MessageHandler> result = new HashSet<>(); + for (MessageHandlerInfo messageMethod : onMessage) { + result.addAll(messageMethod.getMessageHandlers(pojo, pathParameters, + session, config)); + } + return result; + } + + + private static PojoPathParam[] getPathParams(Method m, + MethodType methodType) throws DeploymentException { + if (m == null) { + return new PojoPathParam[0]; + } + boolean foundThrowable = false; + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + PojoPathParam[] result = new PojoPathParam[types.length]; + for (int i = 0; i < types.length; i++) { + Class<?> type = types[i]; + if (type.equals(Session.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_OPEN && + type.equals(EndpointConfig.class)) { + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_ERROR + && type.equals(Throwable.class)) { + foundThrowable = true; + result[i] = new PojoPathParam(type, null); + } else if (methodType == MethodType.ON_CLOSE && + type.equals(CloseReason.class)) { + result[i] = new PojoPathParam(type, null); + } else { + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + // Check that the type is valid. "0" coerces to every + // valid type + try { + Util.coerceToType(type, "0"); + } catch (IllegalArgumentException iae) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.invalidPathParamType"), + iae); + } + result[i] = new PojoPathParam(type, + ((PathParam) paramAnnotation).value()); + break; + } + } + // Parameters without annotations are not permitted + if (result[i] == null) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.paramWithoutAnnotation", + type, m.getName(), m.getClass().getName())); + } + } + } + if (methodType == MethodType.ON_ERROR && !foundThrowable) { + throw new DeploymentException(sm.getString( + "pojoMethodMapping.onErrorNoThrowable", + m.getName(), m.getDeclaringClass().getName())); + } + return result; + } + + + private static Object[] buildArgs(PojoPathParam[] pathParams, + Map<String,String> pathParameters, Session session, + EndpointConfig config, Throwable throwable, CloseReason closeReason) + throws DecodeException { + Object[] result = new Object[pathParams.length]; + for (int i = 0; i < pathParams.length; i++) { + Class<?> type = pathParams[i].getType(); + if (type.equals(Session.class)) { + result[i] = session; + } else if (type.equals(EndpointConfig.class)) { + result[i] = config; + } else if (type.equals(Throwable.class)) { + result[i] = throwable; + } else if (type.equals(CloseReason.class)) { + result[i] = closeReason; + } else { + String name = pathParams[i].getName(); + String value = pathParameters.get(name); + try { + result[i] = Util.coerceToType(type, value); + } catch (Exception e) { + throw new DecodeException(value, sm.getString( + "pojoMethodMapping.decodePathParamFail", + value, type), e); + } + } + } + return result; + } + + + private static class MessageHandlerInfo { + + private final Method m; + private int indexString = -1; + private int indexByteArray = -1; + private int indexByteBuffer = -1; + private int indexPong = -1; + private int indexBoolean = -1; + private int indexSession = -1; + private int indexInputStream = -1; + private int indexReader = -1; + private int indexPrimitive = -1; + private Class<?> primitiveType = null; + private Map<Integer,PojoPathParam> indexPathParams = new HashMap<>(); + private int indexPayload = -1; + private DecoderMatch decoderMatch = null; + private long maxMessageSize = -1; + + public MessageHandlerInfo(Method m, List<DecoderEntry> decoderEntries) { + this.m = m; + + Class<?>[] types = m.getParameterTypes(); + Annotation[][] paramsAnnotations = m.getParameterAnnotations(); + + for (int i = 0; i < types.length; i++) { + boolean paramFound = false; + Annotation[] paramAnnotations = paramsAnnotations[i]; + for (Annotation paramAnnotation : paramAnnotations) { + if (paramAnnotation.annotationType().equals( + PathParam.class)) { + indexPathParams.put( + Integer.valueOf(i), new PojoPathParam(types[i], + ((PathParam) paramAnnotation).value())); + paramFound = true; + break; + } + } + if (paramFound) { + continue; + } + if (String.class.isAssignableFrom(types[i])) { + if (indexString == -1) { + indexString = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Reader.class.isAssignableFrom(types[i])) { + if (indexReader == -1) { + indexReader = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (boolean.class == types[i]) { + if (indexBoolean == -1) { + indexBoolean = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateLastParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (ByteBuffer.class.isAssignableFrom(types[i])) { + if (indexByteBuffer == -1) { + indexByteBuffer = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (byte[].class == types[i]) { + if (indexByteArray == -1) { + indexByteArray = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (InputStream.class.isAssignableFrom(types[i])) { + if (indexInputStream == -1) { + indexInputStream = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Util.isPrimitive(types[i])) { + if (indexPrimitive == -1) { + indexPrimitive = i; + primitiveType = types[i]; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (Session.class.isAssignableFrom(types[i])) { + if (indexSession == -1) { + indexSession = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateSessionParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else if (PongMessage.class.isAssignableFrom(types[i])) { + if (indexPong == -1) { + indexPong = i; + } else { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicatePongMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + } else { + if (decoderMatch != null && decoderMatch.hasMatches()) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } + decoderMatch = new DecoderMatch(types[i], decoderEntries); + + if (decoderMatch.hasMatches()) { + indexPayload = i; + } + } + } + + // Additional checks required + if (indexString != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexString; + } + } + if (indexReader != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexReader; + } + } + if (indexByteArray != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteArray; + } + } + if (indexByteBuffer != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexByteBuffer; + } + } + if (indexInputStream != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexInputStream; + } + } + if (indexPrimitive != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.duplicateMessageParam", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPrimitive; + } + } + if (indexPong != -1) { + if (indexPayload != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.pongWithPayload", + m.getName(), m.getDeclaringClass().getName())); + } else { + indexPayload = indexPong; + } + } + if (indexPayload == -1 && indexPrimitive == -1 && + indexBoolean != -1) { + // The boolean we found is a payload, not a last flag + indexPayload = indexBoolean; + indexPrimitive = indexBoolean; + primitiveType = Boolean.TYPE; + indexBoolean = -1; + } + if (indexPayload == -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.noPayload", + m.getName(), m.getDeclaringClass().getName())); + } + if (indexPong != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialPong", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexReader != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialReader", + m.getName(), m.getDeclaringClass().getName())); + } + if(indexInputStream != -1 && indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialInputStream", + m.getName(), m.getDeclaringClass().getName())); + } + if (decoderMatch != null && decoderMatch.hasMatches() && + indexBoolean != -1) { + throw new IllegalArgumentException(sm.getString( + "pojoMethodMapping.partialObject", + m.getName(), m.getDeclaringClass().getName())); + } + + maxMessageSize = m.getAnnotation(OnMessage.class).maxMessageSize(); + } + + + public boolean targetsSameWebSocketMessageType(MessageHandlerInfo otherHandler) { + if (otherHandler == null) { + return false; + } + if (indexByteArray >= 0 && otherHandler.indexByteArray >= 0) { + return true; + } + if (indexByteBuffer >= 0 && otherHandler.indexByteBuffer >= 0) { + return true; + } + if (indexInputStream >= 0 && otherHandler.indexInputStream >= 0) { + return true; + } + if (indexPong >= 0 && otherHandler.indexPong >= 0) { + return true; + } + if (indexPrimitive >= 0 && otherHandler.indexPrimitive >= 0 + && primitiveType == otherHandler.primitiveType) { + return true; + } + if (indexReader >= 0 && otherHandler.indexReader >= 0) { + return true; + } + if (indexString >= 0 && otherHandler.indexString >= 0) { + return true; + } + if (decoderMatch != null && otherHandler.decoderMatch != null + && decoderMatch.getTarget().equals(otherHandler.decoderMatch.getTarget())) { + return true; + } + return false; + } + + + public Set<MessageHandler> getMessageHandlers(Object pojo, + Map<String,String> pathParameters, Session session, + EndpointConfig config) { + Object[] params = new Object[m.getParameterTypes().length]; + + for (Map.Entry<Integer,PojoPathParam> entry : + indexPathParams.entrySet()) { + PojoPathParam pathParam = entry.getValue(); + String valueString = pathParameters.get(pathParam.getName()); + Object value = null; + try { + value = Util.coerceToType(pathParam.getType(), valueString); + } catch (Exception e) { + DecodeException de = new DecodeException(valueString, + sm.getString( + "pojoMethodMapping.decodePathParamFail", + valueString, pathParam.getType()), e); + params = new Object[] { de }; + break; + } + params[entry.getKey().intValue()] = value; + } + + Set<MessageHandler> results = new HashSet<>(2); + if (indexBoolean == -1) { + // Basic + if (indexString != -1 || indexPrimitive != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexPayload, false, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexReader != -1) { + MessageHandler mh = new PojoMessageHandlerWholeText(pojo, m, + session, config, null, params, indexReader, true, + indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteArray, + true, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexByteBuffer != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexByteBuffer, + false, indexSession, false, maxMessageSize); + results.add(mh); + } else if (indexInputStream != -1) { + MessageHandler mh = new PojoMessageHandlerWholeBinary(pojo, + m, session, config, null, params, indexInputStream, + true, indexSession, true, maxMessageSize); + results.add(mh); + } else if (decoderMatch != null && decoderMatch.hasMatches()) { + if (decoderMatch.getBinaryDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeBinary( + pojo, m, session, config, + decoderMatch.getBinaryDecoders(), params, + indexPayload, true, indexSession, true, + maxMessageSize); + results.add(mh); + } + if (decoderMatch.getTextDecoders().size() > 0) { + MessageHandler mh = new PojoMessageHandlerWholeText( + pojo, m, session, config, + decoderMatch.getTextDecoders(), params, + indexPayload, true, indexSession, maxMessageSize); + results.add(mh); + } + } else { + MessageHandler mh = new PojoMessageHandlerWholePong(pojo, m, + session, params, indexPong, false, indexSession); + results.add(mh); + } + } else { + // ASync + if (indexString != -1) { + MessageHandler mh = new PojoMessageHandlerPartialText(pojo, + m, session, params, indexString, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else if (indexByteArray != -1) { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteArray, true, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } else { + MessageHandler mh = new PojoMessageHandlerPartialBinary( + pojo, m, session, params, indexByteBuffer, false, + indexBoolean, indexSession, maxMessageSize); + results.add(mh); + } + } + return results; + } + } + + + private enum MethodType { + ON_OPEN, + ON_CLOSE, + ON_ERROR + } +} |