ソースを参照

优化websocket

zhou-hao 7 年 前
コミット
09f6f42d15

+ 4 - 1
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/AuthorizeCommandProcessor.java

@@ -36,8 +36,11 @@ public class AuthorizeCommandProcessor extends AbstractCommandProcessor {
             if (token != null) {
                 UserTokenHolder.setCurrent(token);
                 success = Authentication.current().orElse(null) != null;
+                if (success) {
+                    command.getSession().getAttributes().put("user_token", accessToken);
+                }
             }
-            sendMessage(command.getSession(),new WebSocketMessage(200, callback, success));
+            sendMessage(command.getSession(), new WebSocketMessage(200, callback, success));
         }
     }
 }

+ 38 - 7
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/handler/CommandWebSocketMessageDispatcher.java

@@ -5,6 +5,8 @@ import com.fasterxml.jackson.core.JsonParseException;
 import org.hswebframework.web.ThreadLocalUtils;
 import org.hswebframework.web.authorization.Authentication;
 import org.hswebframework.web.authorization.AuthenticationHolder;
+import org.hswebframework.web.authorization.exception.AccessDenyException;
+import org.hswebframework.web.authorization.exception.UnAuthorizedException;
 import org.hswebframework.web.authorization.token.UserToken;
 import org.hswebframework.web.authorization.token.UserTokenHolder;
 import org.hswebframework.web.authorization.token.UserTokenManager;
@@ -23,6 +25,7 @@ import org.springframework.web.socket.TextMessage;
 import org.springframework.web.socket.WebSocketSession;
 import org.springframework.web.socket.handler.TextWebSocketHandler;
 
+import java.nio.file.AccessDeniedException;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
@@ -38,15 +41,18 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
 
     private List<WebSocketSessionListener> webSocketSessionListeners;
 
-    @Autowired(required = false)
     private List<WebSocketTokenParser> tokenParsers;
 
-    private Logger logger= LoggerFactory.getLogger(this.getClass());
+    private Logger logger = LoggerFactory.getLogger(this.getClass());
 
     public void setWebSocketSessionListeners(List<WebSocketSessionListener> webSocketSessionListeners) {
         this.webSocketSessionListeners = webSocketSessionListeners;
     }
 
+    public void setTokenParsers(List<WebSocketTokenParser> tokenParsers) {
+        this.tokenParsers = tokenParsers;
+    }
+
     public void setUserTokenManager(UserTokenManager userTokenManager) {
         this.userTokenManager = userTokenManager;
     }
@@ -65,8 +71,11 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
         if (StringUtils.isEmpty(payload)) {
             return;
         }
+        String cmd = null;
+        WebSocketMessage errorMessage = null;
         try {
             WebSocketCommandRequest request = JSON.parseObject(payload, WebSocketCommandRequest.class);
+            cmd = request.getCommand();
             CommandRequest command = buildCommand(request, session);
             CommandProcessor processor = processorContainer.getProcessor(request.getCommand());
             if (processor != null) {
@@ -76,24 +85,44 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
             }
         } catch (JsonParseException e) {
             session.sendMessage(requestFormatErrorMessage);
+        } catch (UnAuthorizedException e) {
+            errorMessage = new WebSocketMessage(401, "un authorized");
+        } catch (AccessDenyException e) {
+            errorMessage = new WebSocketMessage(403, "access deny");
         } catch (Exception e) {
-            e.printStackTrace();
-            session.sendMessage(new TextMessage(new WebSocketMessage(500, "error!" + e.getMessage()).toString()));
+            logger.warn("handle websocket message error ", e);
+            errorMessage = new WebSocketMessage(500, e.getMessage());
+        } finally {
+            ThreadLocalUtils.clear();
+        }
+        if (errorMessage != null) {
+            errorMessage.setCommand(cmd);
+            session.sendMessage(new TextMessage(errorMessage.toString()));
         }
     }
 
-    private Authentication getAuthenticationFromSession(WebSocketSession socketSession) {
+    private Authentication getAuthenticationFromSession(WebSocketSession session) {
         if (null == userTokenManager) {
             return null;
         }
+        String token = (String) session.getAttributes().get("user_token");
+        if(null==token){
+            return null;
+        }
+        UserToken userToken = userTokenManager.getByToken(token);
+        if (null == userToken) {
+            return null;
+        }
+        UserTokenHolder.setCurrent(userToken);
         return Authentication.current().orElse(null);
     }
 
     private CommandRequest buildCommand(WebSocketCommandRequest request, WebSocketSession socketSession) {
+        Authentication authentication = getAuthenticationFromSession(socketSession);
         return new CommandRequest() {
             @Override
             public Authentication getAuthentication() {
-                return getAuthenticationFromSession(socketSession);
+                return authentication;
             }
 
             @Override
@@ -121,9 +150,11 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
                 if (null != userToken) {
                     UserTokenHolder.setCurrent(userToken);
                     Authentication authentication = Authentication.current().orElse(null);
+                    session.getAttributes().put("user_token", token);
+
                     if (null != authentication) {
                         logger.debug("websocket authentication init ok!");
-                    }else{
+                    } else {
                         logger.debug("websocket authentication init fail!");
                     }
                 }

+ 5 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/starter/CommandWebSocketAutoConfiguration.java

@@ -6,6 +6,7 @@ import org.hswebframework.web.message.Messager;
 import org.hswebframework.web.socket.WebSocketSessionListener;
 import org.hswebframework.web.socket.authorize.AuthorizeCommandProcessor;
 import org.hswebframework.web.socket.authorize.SessionIdWebSocketTokenParser;
+import org.hswebframework.web.socket.authorize.WebSocketTokenParser;
 import org.hswebframework.web.socket.authorize.XAccessTokenParser;
 import org.hswebframework.web.socket.handler.CommandWebSocketMessageDispatcher;
 import org.hswebframework.web.socket.message.DefaultWebSocketMessager;
@@ -82,6 +83,9 @@ public class CommandWebSocketAutoConfiguration {
         @Autowired(required = false)
         private List<WebSocketSessionListener> webSocketSessionListeners;
 
+        @Autowired(required = false)
+        private List<WebSocketTokenParser> webSocketTokenParsers;
+
         @Autowired
         private CommandProcessorContainer commandProcessorContainer;
 
@@ -91,6 +95,7 @@ public class CommandWebSocketAutoConfiguration {
             dispatcher.setProcessorContainer(commandProcessorContainer);
             dispatcher.setUserTokenManager(userTokenManager);
             dispatcher.setWebSocketSessionListeners(webSocketSessionListeners);
+            dispatcher.setTokenParsers(webSocketTokenParsers);
             registry.addHandler(dispatcher, "/sockjs")
                     .withSockJS()
                     .setSessionCookieNeeded(true);