Browse Source

优化websocket

zhou-hao 7 years ago
parent
commit
baf88d1e59
13 changed files with 229 additions and 40 deletions
  1. 0 2
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/CommandRequest.java
  2. 43 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/AuthorizeCommandProcessor.java
  3. 27 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/SessionIdWebSocketTokenParser.java
  4. 7 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/WebSocketTokenParser.java
  5. 13 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/XAccessTokenParser.java
  6. 35 1
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/handler/CommandWebSocketMessageDispatcher.java
  7. 0 2
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/handler/WebSocketCommandRequest.java
  8. 2 2
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/message/DefaultWebSocketMessager.java
  9. 27 15
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/message/WebSocketMessage.java
  10. 34 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/processor/AbstractCommandProcessor.java
  11. 19 0
      hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/starter/CommandWebSocketAutoConfiguration.java
  12. 11 12
      hsweb-message/hsweb-message-websocket/src/test/java/org/hswebframework/web/socket/WebSocketClientTests.java
  13. 11 6
      hsweb-message/hsweb-message-websocket/src/test/java/org/hswebframework/web/socket/WebSocketServerTests.java

+ 0 - 2
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/CommandRequest.java

@@ -6,8 +6,6 @@ import org.springframework.web.socket.WebSocketSession;
 import java.util.Map;
 
 /**
- * TODO 完成注释
- *
  * @author zhouhao
  */
 public interface CommandRequest {

+ 43 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/AuthorizeCommandProcessor.java

@@ -0,0 +1,43 @@
+package org.hswebframework.web.socket.authorize;
+
+import org.hswebframework.web.authorization.Authentication;
+import org.hswebframework.web.authorization.token.UserToken;
+import org.hswebframework.web.authorization.token.UserTokenHolder;
+import org.hswebframework.web.authorization.token.UserTokenManager;
+import org.hswebframework.web.socket.CommandRequest;
+import org.hswebframework.web.socket.message.WebSocketMessage;
+import org.hswebframework.web.socket.processor.AbstractCommandProcessor;
+
+public class AuthorizeCommandProcessor extends AbstractCommandProcessor {
+
+    private UserTokenManager userTokenManager;
+
+    public AuthorizeCommandProcessor(UserTokenManager userTokenManager) {
+        this.userTokenManager = userTokenManager;
+    }
+
+    public void setUserTokenManager(UserTokenManager userTokenManager) {
+        this.userTokenManager = userTokenManager;
+    }
+
+    @Override
+    public String getName() {
+        return "authorize";
+    }
+
+    @Override
+    public void execute(CommandRequest command) {
+        String accessToken = (String) command.getParameters().get("access_token");
+        String callback = (String) command.getParameters().getOrDefault("callback", "authorize");
+        boolean success = false;
+
+        if (null != accessToken) {
+            UserToken token = userTokenManager.getByToken(accessToken);
+            if (token != null) {
+                UserTokenHolder.setCurrent(token);
+                success = Authentication.current().orElse(null) != null;
+            }
+            sendMessage(command.getSession(),new WebSocketMessage(200, callback, success));
+        }
+    }
+}

+ 27 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/SessionIdWebSocketTokenParser.java

@@ -0,0 +1,27 @@
+package org.hswebframework.web.socket.authorize;
+
+import org.springframework.http.HttpHeaders;
+import org.springframework.web.socket.WebSocketSession;
+
+import java.util.*;
+
+public class SessionIdWebSocketTokenParser implements WebSocketTokenParser {
+    @Override
+    public String parseToken(WebSocketSession session) {
+        HttpHeaders headers = session.getHandshakeHeaders();
+        List<String> cookies = headers.get("Cookie");
+        if (cookies == null || cookies.isEmpty()) {
+            return null;
+        }
+        String[] cookie = cookies.get(0).split("[;]");
+        Map<String, Set<String>> sessionId = new HashMap<>();
+        for (String aCookie : cookie) {
+            String[] tmp = aCookie.split("[=]");
+            if (tmp.length == 2) {
+                sessionId.computeIfAbsent(tmp[0].trim().toUpperCase(), k -> new HashSet<>())
+                        .add(tmp[1].trim());
+            }
+        }
+        return sessionId.getOrDefault("JSESSIONID", sessionId.getOrDefault("SESSIONID", Collections.emptySet())).stream().findFirst().orElse(null);
+    }
+}

+ 7 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/WebSocketTokenParser.java

@@ -0,0 +1,7 @@
+package org.hswebframework.web.socket.authorize;
+
+import org.springframework.web.socket.WebSocketSession;
+
+public interface WebSocketTokenParser {
+    String parseToken(WebSocketSession session);
+}

+ 13 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/authorize/XAccessTokenParser.java

@@ -0,0 +1,13 @@
+package org.hswebframework.web.socket.authorize;
+
+import org.springframework.web.socket.WebSocketSession;
+
+import java.util.List;
+
+public class XAccessTokenParser implements WebSocketTokenParser {
+    @Override
+    public String parseToken(WebSocketSession session) {
+        List<String> tokens = session.getHandshakeHeaders().get("x-access-token");
+        return tokens == null || tokens.isEmpty() ? null : tokens.get(0);
+    }
+}

+ 35 - 1
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/handler/CommandWebSocketMessageDispatcher.java

@@ -2,13 +2,21 @@ package org.hswebframework.web.socket.handler;
 
 import com.alibaba.fastjson.JSON;
 import com.fasterxml.jackson.core.JsonParseException;
+import org.hswebframework.web.ThreadLocalUtils;
 import org.hswebframework.web.authorization.Authentication;
+import org.hswebframework.web.authorization.AuthenticationHolder;
+import org.hswebframework.web.authorization.token.UserToken;
+import org.hswebframework.web.authorization.token.UserTokenHolder;
 import org.hswebframework.web.authorization.token.UserTokenManager;
 import org.hswebframework.web.socket.CommandRequest;
 import org.hswebframework.web.socket.WebSocketSessionListener;
+import org.hswebframework.web.socket.authorize.WebSocketTokenParser;
 import org.hswebframework.web.socket.message.WebSocketMessage;
 import org.hswebframework.web.socket.processor.CommandProcessor;
 import org.hswebframework.web.socket.processor.CommandProcessorContainer;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.util.StringUtils;
 import org.springframework.web.socket.CloseStatus;
 import org.springframework.web.socket.TextMessage;
@@ -17,6 +25,7 @@ import org.springframework.web.socket.handler.TextWebSocketHandler;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 /**
  * @author zhouhao
@@ -29,6 +38,11 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
 
     private List<WebSocketSessionListener> webSocketSessionListeners;
 
+    @Autowired(required = false)
+    private List<WebSocketTokenParser> tokenParsers;
+
+    private Logger logger= LoggerFactory.getLogger(this.getClass());
+
     public void setWebSocketSessionListeners(List<WebSocketSessionListener> webSocketSessionListeners) {
         this.webSocketSessionListeners = webSocketSessionListeners;
     }
@@ -72,7 +86,7 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
         if (null == userTokenManager) {
             return null;
         }
-        return WebSocketUtils.getAuthentication(userTokenManager, socketSession);
+        return Authentication.current().orElse(null);
     }
 
     private CommandRequest buildCommand(WebSocketCommandRequest request, WebSocketSession socketSession) {
@@ -96,6 +110,25 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
 
     @Override
     public void afterConnectionEstablished(WebSocketSession session) throws Exception {
+        if (tokenParsers != null) {
+            String token = tokenParsers.stream()
+                    .map(parser -> parser.parseToken(session))
+                    .filter(Objects::nonNull)
+                    .findFirst()
+                    .orElse(null);
+            if (null != token) {
+                UserToken userToken = userTokenManager.getByToken(token);
+                if (null != userToken) {
+                    UserTokenHolder.setCurrent(userToken);
+                    Authentication authentication = Authentication.current().orElse(null);
+                    if (null != authentication) {
+                        logger.debug("websocket authentication init ok!");
+                    }else{
+                        logger.debug("websocket authentication init fail!");
+                    }
+                }
+            }
+        }
         if (webSocketSessionListeners != null) {
             webSocketSessionListeners.forEach(webSocketSessionListener ->
                     webSocketSessionListener.onSessionConnect(session));
@@ -104,6 +137,7 @@ public class CommandWebSocketMessageDispatcher extends TextWebSocketHandler {
 
     @Override
     public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
+        ThreadLocalUtils.clear();
         if (webSocketSessionListeners != null) {
             webSocketSessionListeners.forEach(webSocketSessionListener ->
                     webSocketSessionListener.onSessionClose(session));

+ 0 - 2
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/handler/WebSocketCommandRequest.java

@@ -3,8 +3,6 @@ package org.hswebframework.web.socket.handler;
 import java.util.Map;
 
 /**
- * TODO 完成注释
- *
  * @author zhouhao
  */
 public class WebSocketCommandRequest {

+ 2 - 2
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/message/DefaultWebSocketMessager.java

@@ -27,6 +27,7 @@ import static org.hswebframework.web.message.builder.StaticMessageSubjectBuilder
 public class DefaultWebSocketMessager implements WebSocketMessager {
 
     private Messager messager;
+    private CounterManager counterManager;
 
     public DefaultWebSocketMessager(Messager messager) {
         this(messager, new SimpleCounterManager());
@@ -34,13 +35,12 @@ public class DefaultWebSocketMessager implements WebSocketMessager {
 
     public DefaultWebSocketMessager(Messager messager, CounterManager counterManager) {
         this.messager = messager;
-        this.counterManager = counterManager == null ? new SimpleCounterManager() : counterManager;
+        this.counterManager = counterManager;
     }
 
     //              command,   type,     sessionId
     private final Map<String, Map<String, Map<String, MessageSubscribeSession>>> store = new ConcurrentHashMap<>(32);
 
-    private CounterManager counterManager = new SimpleCounterManager();
 
 
     @Override

+ 27 - 15
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/message/WebSocketMessage.java

@@ -8,18 +8,22 @@ import java.io.Serializable;
  * @author zhouhao
  */
 public class WebSocketMessage implements Serializable {
-    private int code;
+    private static final long serialVersionUID = -1173161338949028545L;
+
+    private String command;
+
+    private int status;
 
     private String message;
 
-    private Object data;
+    private Object result;
 
-    public int getCode() {
-        return code;
+    public int getStatus() {
+        return status;
     }
 
-    public void setCode(int code) {
-        this.code = code;
+    public void setStatus(int status) {
+        this.status = status;
     }
 
     public String getMessage() {
@@ -30,12 +34,20 @@ public class WebSocketMessage implements Serializable {
         this.message = message;
     }
 
-    public Object getData() {
-        return data;
+    public Object getResult() {
+        return result;
+    }
+
+    public void setResult(Object result) {
+        this.result = result;
+    }
+
+    public String getCommand() {
+        return command;
     }
 
-    public void setData(Object data) {
-        this.data = data;
+    public void setCommand(String command) {
+        this.command = command;
     }
 
     @Override
@@ -46,14 +58,14 @@ public class WebSocketMessage implements Serializable {
     public WebSocketMessage() {
     }
 
-    public WebSocketMessage(int code, String message) {
-        this.code = code;
+    public WebSocketMessage(int status, String message) {
+        this.status = status;
         this.message = message;
     }
 
-    public WebSocketMessage(int code, String message, Object data) {
-        this.code = code;
+    public WebSocketMessage(int status, String message, Object result) {
+        this.status = status;
         this.message = message;
-        this.data = data;
+        this.result = result;
     }
 }

+ 34 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/processor/AbstractCommandProcessor.java

@@ -0,0 +1,34 @@
+package org.hswebframework.web.socket.processor;
+
+import org.hswebframework.web.socket.message.WebSocketMessage;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.web.socket.TextMessage;
+import org.springframework.web.socket.WebSocketSession;
+
+import java.io.IOException;
+
+public abstract class AbstractCommandProcessor implements CommandProcessor {
+
+    protected Logger logger = LoggerFactory.getLogger(this.getClass());
+
+
+    protected void sendMessage(WebSocketSession session, WebSocketMessage message) {
+        message.setCommand(getName());
+        try {
+            session.sendMessage(new TextMessage(message.toString()));
+        } catch (IOException e) {
+            logger.error("send websocket message to {} error", session.getId(), message.toString(), e);
+        }
+    }
+
+    @Override
+    public void init() {
+
+    }
+
+    @Override
+    public void destroy() {
+
+    }
+}

+ 19 - 0
hsweb-message/hsweb-message-websocket/src/main/java/org/hswebframework/web/socket/starter/CommandWebSocketAutoConfiguration.java

@@ -4,6 +4,9 @@ import org.hswebframework.web.authorization.token.UserTokenManager;
 import org.hswebframework.web.concurrent.counter.CounterManager;
 import org.hswebframework.web.message.Messager;
 import org.hswebframework.web.socket.WebSocketSessionListener;
+import org.hswebframework.web.socket.authorize.AuthorizeCommandProcessor;
+import org.hswebframework.web.socket.authorize.SessionIdWebSocketTokenParser;
+import org.hswebframework.web.socket.authorize.XAccessTokenParser;
 import org.hswebframework.web.socket.handler.CommandWebSocketMessageDispatcher;
 import org.hswebframework.web.socket.message.DefaultWebSocketMessager;
 import org.hswebframework.web.socket.message.WebSocketMessager;
@@ -28,6 +31,22 @@ import java.util.List;
 @Configuration
 public class CommandWebSocketAutoConfiguration {
 
+    @Bean
+    public SessionIdWebSocketTokenParser sessionIdWebSocketTokenParser(){
+        return new SessionIdWebSocketTokenParser();
+    }
+
+    @Bean
+    public XAccessTokenParser xAccessTokenParser(){
+        return new XAccessTokenParser();
+    }
+
+    @Bean
+    @ConditionalOnBean(UserTokenManager.class)
+    public AuthorizeCommandProcessor authorizeCommandProcessor(UserTokenManager userTokenManager){
+        return new AuthorizeCommandProcessor(userTokenManager);
+    }
+
     @Configuration
     @ConditionalOnMissingBean(CommandProcessorContainer.class)
     public static class WebSocketProcessorContainerConfiguration {

+ 11 - 12
hsweb-message/hsweb-message-websocket/src/test/java/org/hswebframework/web/socket/WebSocketClientTests.java

@@ -10,19 +10,18 @@ import org.springframework.web.socket.handler.AbstractWebSocketHandler;
 
 public class WebSocketClientTests {
     public static void main(String[] args) throws Exception {
-//        for (int i = 0; i < 10; i++) {
-            WebSocketClient client = new StandardWebSocketClient();
-            String url = "ws://localhost:8081/socket";
-            ListenableFuture<WebSocketSession> future = client.doHandshake(new AbstractWebSocketHandler() {
-                @Override
-                public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
-                    System.out.println(message.getPayload());
-                }
-            }, url);
+        WebSocketClient client = new StandardWebSocketClient();
+        String url = "ws://localhost:8081/socket";
+        ListenableFuture<WebSocketSession> future = client.doHandshake(new AbstractWebSocketHandler() {
+            @Override
+            public void handleMessage(WebSocketSession session, WebSocketMessage<?> message) throws Exception {
+                System.out.println(message.getPayload());
+            }
+        }, url);
+
+        WebSocketSession socketSession = future.get();
+        socketSession.sendMessage(new TextMessage("{\"command\":\"test\",\"parameters\":{\"type\":\"conn\"}}"));
 
-            WebSocketSession socketSession = future.get();
-            socketSession.sendMessage(new TextMessage("{\"command\":\"test\",\"parameters\":{\"type\":\"conn\"}}"));
-//        }
         System.in.read();
     }
 }

+ 11 - 6
hsweb-message/hsweb-message-websocket/src/test/java/org/hswebframework/web/socket/WebSocketServerTests.java

@@ -5,6 +5,7 @@ import org.hswebframework.web.concurrent.counter.CounterManager;
 import org.hswebframework.web.counter.redis.RedissonCounterManager;
 import org.hswebframework.web.message.Messager;
 import org.hswebframework.web.message.jms.JmsMessager;
+import org.hswebframework.web.message.redis.RedissonMessager;
 import org.redisson.Redisson;
 import org.redisson.api.RedissonClient;
 import org.redisson.config.Config;
@@ -30,8 +31,15 @@ public class WebSocketServerTests {
     }
 
     @Bean
-    public Messager messager(JmsTemplate template) {
-        return new JmsMessager(template);
+    public Messager messager(RedissonClient client) {
+        return new RedissonMessager(client);
+    }
+
+    @Bean(destroyMethod = "shutdown")
+    public RedissonClient redissonClient(){
+        Config config = new Config();
+        config.useSingleServer().setAddress("redis://127.0.0.1:6379");
+        return Redisson.create(config);
     }
 
     @Bean
@@ -40,10 +48,7 @@ public class WebSocketServerTests {
     }
 
     @Bean
-    public CounterManager counterManager() {
-        Config config = new Config();
-        config.useSingleServer().setAddress("127.0.0.1:6379");
-        RedissonClient client = Redisson.create(config);
+    public CounterManager counterManager(RedissonClient client) {
         return new RedissonCounterManager(client);
     }