Add PrincipalMessageArgumentResolver

This commit is contained in:
Rossen Stoyanchev
2013-07-12 15:15:37 -04:00
parent d3cecfc6cc
commit 210be9cde4
8 changed files with 146 additions and 52 deletions

View File

@@ -0,0 +1,54 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.handler.method;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessagingException;
/**
*
* @author Rossen Stoyanchev
* @since 4.0
*/
public class InvalidMessageMethodParameterException extends MessagingException {
private static final long serialVersionUID = -6905878930083523161L;
private final MethodParameter parameter;
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter, Throwable cause) {
super(message, description, cause);
this.parameter = parameter;
}
public InvalidMessageMethodParameterException(Message<?> message, String description,
MethodParameter parameter) {
super(message, description);
this.parameter = parameter;
}
public MethodParameter getParameter() {
return this.parameter;
}
}

View File

@@ -16,12 +16,14 @@
package org.springframework.messaging.simp;
import java.security.Principal;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.springframework.http.MediaType;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.support.NativeMessageHeaderAccessor;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
@@ -43,9 +45,6 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String DESTINATIONS = "destinations";
// TODO
public static final String CONTENT_TYPE = "contentType";
public static final String MESSAGE_TYPE = "messageType";
public static final String PROTOCOL_MESSAGE_TYPE = "protocolMessageType";
@@ -54,6 +53,8 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
public static final String SUBSCRIPTION_ID = "subscriptionId";
public static final String USER = "user";
/**
* A constructor for creating new message headers.
@@ -140,12 +141,11 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
}
public MediaType getContentType() {
return (MediaType) getHeader(CONTENT_TYPE);
return (MediaType) getHeader(MessageHeaders.CONTENT_TYPE);
}
public void setContentType(MediaType contentType) {
Assert.notNull(contentType, "contentType is required");
setHeader(CONTENT_TYPE, contentType);
setHeader(MessageHeaders.CONTENT_TYPE, contentType);
}
public String getSubscriptionId() {
@@ -164,4 +164,12 @@ public class SimpMessageHeaderAccessor extends NativeMessageHeaderAccessor {
setHeader(SESSION_ID, sessionId);
}
public Principal getUser() {
return (Principal) getHeader(USER);
}
public void setUser(Principal principal) {
setHeader(USER, principal);
}
}

View File

@@ -0,0 +1,51 @@
/*
* Copyright 2002-2013 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.messaging.simp.annotation.support;
import java.security.Principal;
import org.springframework.core.MethodParameter;
import org.springframework.messaging.Message;
import org.springframework.messaging.handler.method.InvalidMessageMethodParameterException;
import org.springframework.messaging.handler.method.MessageArgumentResolver;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
/**
* @author Rossen Stoyanchev
* @since 4.0
*/
public class PrincipalMessageArgumentResolver implements MessageArgumentResolver {
@Override
public boolean supportsParameter(MethodParameter parameter) {
Class<?> paramType = parameter.getParameterType();
return Principal.class.isAssignableFrom(paramType);
}
@Override
public Object resolveArgument(MethodParameter parameter, Message<?> message) throws Exception {
SimpMessageHeaderAccessor headers = SimpMessageHeaderAccessor.wrap(message);
Principal user = headers.getUser();
if (user == null) {
throw new InvalidMessageMethodParameterException(message, "User not available", parameter);
}
return user;
}
}

View File

@@ -37,15 +37,16 @@ import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.support.MessageBodyArgumentResolver;
import org.springframework.messaging.handler.annotation.support.MessageExceptionHandlerMethodResolver;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.InvocableMessageHandlerMethod;
import org.springframework.messaging.handler.method.MessageArgumentResolverComposite;
import org.springframework.messaging.handler.method.MessageReturnValueHandlerComposite;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.MessageHolder;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.annotation.SubscribeEvent;
import org.springframework.messaging.simp.annotation.UnsubscribeEvent;
import org.springframework.messaging.simp.annotation.support.MessageSendingReturnValueHandler;
import org.springframework.messaging.simp.annotation.support.PrincipalMessageArgumentResolver;
import org.springframework.messaging.support.converter.MessageConverter;
import org.springframework.stereotype.Controller;
import org.springframework.util.Assert;
@@ -113,6 +114,7 @@ public class AnnotationSimpMessageHandler extends AbstractSimpMessageHandler
initHandlerMethods();
this.argumentResolvers.addResolver(new PrincipalMessageArgumentResolver());
this.argumentResolvers.addResolver(new MessageBodyArgumentResolver(this.messageConverter));
this.returnValueHandlers.addHandler(

View File

@@ -46,7 +46,7 @@ public class StompMessageConverter {
/**
* @param stompContent a complete STOMP message (without the trailing 0x00) as byte[] or String.
*/
public Message<?> toMessage(Object stompContent, String sessionId) {
public Message<?> toMessage(Object stompContent) {
byte[] byteContent = null;
if (stompContent instanceof String) {
@@ -91,12 +91,10 @@ public class StompMessageConverter {
}
}
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
stompHeaders.setSessionId(sessionId);
byte[] payload = new byte[totalLength - payloadIndex];
System.arraycopy(byteContent, payloadIndex, payload, 0, totalLength - payloadIndex);
StompHeaderAccessor stompHeaders = StompHeaderAccessor.create(command, headers);
return MessageBuilder.withPayload(payload).copyHeaders(stompHeaders.toMap()).build();
}

View File

@@ -29,9 +29,9 @@ import java.util.concurrent.TimeUnit;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
import org.springframework.messaging.simp.SimpMessageType;
import org.springframework.messaging.simp.handler.AbstractSimpMessageHandler;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
@@ -350,7 +350,7 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
return;
}
Message<?> message = stompMessageConverter.toMessage(stompFrame, this.sessionId);
Message<?> message = stompMessageConverter.toMessage(stompFrame);
if (logger.isTraceEnabled()) {
logger.trace("Reading message " + message);
}
@@ -369,6 +369,10 @@ public class StompRelayMessageHandler extends AbstractSimpMessageHandler impleme
}
relaySessions.remove(this.sessionId);
}
headers.setSessionId(this.sessionId);
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
sendMessageToClient(message);
}

View File

@@ -81,7 +81,11 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
protected void handleTextMessage(WebSocketSession session, TextMessage textMessage) {
try {
String payload = textMessage.getPayload();
Message<?> message = this.stompMessageConverter.toMessage(payload, session.getId());
Message<?> message = this.stompMessageConverter.toMessage(payload);
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
headers.setSessionId(session.getId());
headers.setUser(session.getPrincipal());
// TODO: validate size limits
// http://stomp.github.io/stomp-specification-1.2.html#Size_Limits
@@ -96,18 +100,8 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
if (SimpMessageType.CONNECT.equals(messageType)) {
handleConnect(session, message);
}
else if (SimpMessageType.MESSAGE.equals(messageType)) {
handlePublish(message);
}
else if (SimpMessageType.SUBSCRIBE.equals(messageType)) {
handleSubscribe(message);
}
else if (SimpMessageType.UNSUBSCRIBE.equals(messageType)) {
handleUnsubscribe(message);
}
else if (SimpMessageType.DISCONNECT.equals(messageType)) {
handleDisconnect(message);
}
message = MessageBuilder.fromMessage(message).copyHeaders(headers.toMap()).build();
this.outputChannel.send(message);
}
catch (Throwable t) {
@@ -124,7 +118,7 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
}
}
protected void handleConnect(final WebSocketSession session, Message<?> message) throws IOException {
protected void handleConnect(WebSocketSession session, Message<?> message) throws IOException {
StompHeaderAccessor connectHeaders = StompHeaderAccessor.wrap(message);
StompHeaderAccessor connectedHeaders = StompHeaderAccessor.create(StompCommand.CONNECTED);
@@ -152,18 +146,6 @@ public class StompWebSocketHandler extends TextWebSocketHandlerAdapter implement
session.sendMessage(new TextMessage(new String(bytes, Charset.forName("UTF-8"))));
}
protected void handlePublish(Message<?> stompMessage) {
}
protected void handleSubscribe(Message<?> message) {
}
protected void handleUnsubscribe(Message<?> message) {
}
protected void handleDisconnect(Message<?> message) {
}
protected void sendErrorMessage(WebSocketSession session, Throwable error) {
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.ERROR);