zhouhao 7 年之前
父節點
當前提交
2f2ea9ce67

+ 26 - 12
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/SimpleOAuth2SessionBuilder.java

@@ -27,6 +27,8 @@ import org.hswebframework.web.authorization.oauth2.client.simple.session.Default
 import org.hswebframework.web.authorization.oauth2.client.simple.session.PasswordSession;
 
 import java.util.List;
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 
@@ -41,6 +43,8 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
 
     private OAuth2RequestBuilderFactory requestBuilderFactory;
 
+    private ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
+
     public SimpleOAuth2SessionBuilder(OAuth2UserTokenRepository oAuth2UserTokenRepository,
                                       OAuth2ServerConfig oAuth2ServerConfig,
                                       OAuth2RequestBuilderFactory requestBuilderFactory) {
@@ -61,23 +65,33 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
 
 
     protected AccessTokenInfo getClientCredentialsToken() {
-        List<AccessTokenInfo> list = oAuth2UserTokenRepository
-                .findByServerIdAndGrantType(serverConfig.getId(), GrantType.client_credentials);
-        return list.isEmpty() ? null : list.get(0);
+        readWriteLock.readLock().lock();
+        try {
+            List<AccessTokenInfo> list = oAuth2UserTokenRepository
+                    .findByServerIdAndGrantType(serverConfig.getId(), GrantType.client_credentials);
+            return list.isEmpty() ? null : list.get(0);
+        } finally {
+            readWriteLock.readLock().unlock();
+        }
     }
 
     protected Consumer<AccessTokenInfo> createOnTokenChanged(Supplier<AccessTokenInfo> tokenGetter, String grantType) {
         return token -> {
             AccessTokenInfo tokenInfo = tokenGetter.get();
-            if (tokenInfo != null) {
-                token.setId(tokenInfo.getId());
-                tokenInfo.setUpdateTime(System.currentTimeMillis());
-                oAuth2UserTokenRepository.update(tokenInfo.getId(), token);
-            } else {
-                token.setGrantType(grantType);
-                token.setCreateTime(System.currentTimeMillis());
-                token.setServerId(serverConfig.getId());
-                oAuth2UserTokenRepository.insert(token);
+            readWriteLock.writeLock().lock();
+            try {
+                if (tokenInfo != null) {
+                    token.setId(tokenInfo.getId());
+                    tokenInfo.setUpdateTime(System.currentTimeMillis());
+                    oAuth2UserTokenRepository.update(tokenInfo.getId(), token);
+                } else {
+                    token.setGrantType(grantType);
+                    token.setCreateTime(System.currentTimeMillis());
+                    token.setServerId(serverConfig.getId());
+                    oAuth2UserTokenRepository.insert(token);
+                }
+            } finally {
+                readWriteLock.writeLock().unlock();
             }
         };
     }

+ 5 - 0
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/provider/HswebResponseConvertSupport.java

@@ -21,8 +21,10 @@ package org.hswebframework.web.authorization.oauth2.client.simple.provider;
 import com.alibaba.fastjson.JSON;
 import org.hswebframework.web.authorization.Authentication;
 import org.hswebframework.web.authorization.builder.AuthenticationBuilderFactory;
+import org.hswebframework.web.authorization.oauth2.client.exception.OAuth2RequestException;
 import org.hswebframework.web.authorization.oauth2.client.request.definition.ResponseConvertForProviderDefinition;
 import org.hswebframework.web.authorization.oauth2.client.response.OAuth2Response;
+import org.hswebframework.web.oauth2.core.ErrorType;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Component;
 
@@ -42,6 +44,9 @@ public class HswebResponseConvertSupport implements ResponseConvertForProviderDe
     @Override
     public <T> T convert(OAuth2Response response, Class<T> type) {
         String json = response.asString();
+        if (response.status() != 200) {
+            throw new OAuth2RequestException(ErrorType.OTHER, response);
+        }
         if (type == Authentication.class) {
             if (authenticationBuilderFactory != null) {
                 return (T) authenticationBuilderFactory.create().json(json).build();

+ 37 - 17
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/session/DefaultOAuth2Session.java

@@ -25,6 +25,8 @@ import org.hswebframework.web.authorization.oauth2.client.request.OAuth2Session;
 import org.hswebframework.web.authorization.oauth2.client.response.OAuth2Response;
 import org.springframework.util.Assert;
 
+import java.util.concurrent.locks.ReadWriteLock;
+import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.function.Consumer;
 
 import static org.hswebframework.web.authorization.oauth2.client.OAuth2Constants.*;
@@ -48,6 +50,10 @@ public class DefaultOAuth2Session implements OAuth2Session {
 
     private Consumer<AccessTokenInfo> onTokenChange;
 
+    private boolean authorized = false;
+
+    private ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
+
     public void setRequestBuilderFactory(OAuth2RequestBuilderFactory requestBuilderFactory) {
         this.requestBuilderFactory = requestBuilderFactory;
     }
@@ -102,12 +108,21 @@ public class DefaultOAuth2Session implements OAuth2Session {
 
     @Override
     public OAuth2Session authorize() {
-        AccessTokenInfo accessTokenInfo = accessTokenRequest
-                .param(OAuth2Constants.scope, scope)
-                .post().onError(OAuth2Response.throwOnError)
-                .as(AccessTokenInfo.class);
-        accessTokenInfo.setCreateTime(System.currentTimeMillis());
-        setAccessTokenInfo(accessTokenInfo);
+        readWriteLock.writeLock().lock();
+        if (authorized) {
+            return this;
+        }
+        try {
+            AccessTokenInfo accessTokenInfo = accessTokenRequest
+                    .param(OAuth2Constants.scope, scope)
+                    .post().onError(OAuth2Response.throwOnError)
+                    .as(AccessTokenInfo.class);
+            accessTokenInfo.setCreateTime(System.currentTimeMillis());
+            setAccessTokenInfo(accessTokenInfo);
+            authorized = true;
+        } finally {
+            readWriteLock.writeLock().unlock();
+        }
         return this;
     }
 
@@ -136,17 +151,22 @@ public class DefaultOAuth2Session implements OAuth2Session {
         if (accessTokenInfo == null) {
             return;
         }
-        OAuth2Request request = createRequest(getRealUrl(serverConfig.getAccessTokenUrl()));
-        applyBasicAuthParam(request);
-        AccessTokenInfo tokenInfo = request
-                .param(OAuth2Constants.scope, scope)
-                .param(OAuth2Constants.grant_type, GrantType.refresh_token)
-                .param(GrantType.refresh_token, accessTokenInfo.getRefreshToken())
-                .post().onError(OAuth2Response.throwOnError)
-                .as(AccessTokenInfo.class);
-        tokenInfo.setCreateTime(accessTokenInfo.getCreateTime());
-        tokenInfo.setUpdateTime(System.currentTimeMillis());
-        setAccessTokenInfo(tokenInfo);
+        readWriteLock.writeLock().lock();
+        try {
+            OAuth2Request request = createRequest(getRealUrl(serverConfig.getAccessTokenUrl()));
+            applyBasicAuthParam(request);
+            AccessTokenInfo tokenInfo = request
+                    .param(OAuth2Constants.scope, scope)
+                    .param(OAuth2Constants.grant_type, GrantType.refresh_token)
+                    .param(GrantType.refresh_token, accessTokenInfo.getRefreshToken())
+                    .post().onError(OAuth2Response.throwOnError)
+                    .as(AccessTokenInfo.class);
+            tokenInfo.setCreateTime(accessTokenInfo.getCreateTime());
+            tokenInfo.setUpdateTime(System.currentTimeMillis());
+            setAccessTokenInfo(tokenInfo);
+        } finally {
+            readWriteLock.writeLock().unlock();
+        }
     }