Browse Source

优化并发下OAuth2 token重复问题

zhouhao 7 years ago
parent
commit
42508d621d

+ 5 - 0
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/pom.xml

@@ -31,6 +31,11 @@
 
 
     <dependencies>
+        <dependency>
+            <groupId>org.hswebframework.web</groupId>
+            <artifactId>hsweb-concurrent-lock-starter</artifactId>
+            <version>${project.version}</version>
+        </dependency>
         <dependency>
             <groupId>org.hswebframework.web</groupId>
             <artifactId>hsweb-authorization-api</artifactId>

+ 7 - 2
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/OAuth2ClientAutoConfiguration.java

@@ -8,6 +8,7 @@ import org.hswebframework.web.authorization.oauth2.client.simple.*;
 import org.hswebframework.web.authorization.oauth2.client.simple.provider.HswebResponseConvertSupport;
 import org.hswebframework.web.authorization.oauth2.client.simple.provider.HswebResponseJudgeSupport;
 import org.hswebframework.web.authorization.oauth2.client.simple.request.builder.SimpleOAuth2RequestBuilderFactory;
+import org.hswebframework.web.concurrent.lock.LockManager;
 import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
 import org.springframework.boot.context.properties.ConfigurationProperties;
 import org.springframework.context.annotation.Bean;
@@ -42,8 +43,12 @@ public class OAuth2ClientAutoConfiguration {
 
     @ConditionalOnMissingBean(OAuth2RequestService.class)
     @Bean
-    public SimpleOAuth2RequestService simpleOAuth2RequestService(OAuth2ServerConfigRepository configRepository, OAuth2UserTokenRepository userTokenRepository, OAuth2RequestBuilderFactory builderFactory) {
-        return new SimpleOAuth2RequestService(configRepository, userTokenRepository, builderFactory);
+    public SimpleOAuth2RequestService simpleOAuth2RequestService(OAuth2ServerConfigRepository configRepository
+            , OAuth2UserTokenRepository userTokenRepository
+            , OAuth2RequestBuilderFactory builderFactory
+            , LockManager lockManager) {
+
+        return new SimpleOAuth2RequestService(configRepository, userTokenRepository, builderFactory,lockManager);
     }
 
     @ConditionalOnMissingBean(OAuth2ServerConfigRepository.class)

+ 2 - 0
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/request/OAuth2Session.java

@@ -67,6 +67,8 @@ public interface OAuth2Session extends Serializable {
      */
     boolean isClosed();
 
+    AccessTokenInfo requestAccessToken();
+
     AccessTokenInfo getAccessToken();
 
 }

+ 16 - 2
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/SimpleOAuth2RequestService.java

@@ -26,6 +26,7 @@ import org.hswebframework.web.authorization.oauth2.client.OAuth2ServerConfig;
 import org.hswebframework.web.authorization.oauth2.client.OAuth2SessionBuilder;
 import org.hswebframework.web.authorization.oauth2.client.listener.OAuth2Event;
 import org.hswebframework.web.authorization.oauth2.client.listener.OAuth2Listener;
+import org.hswebframework.web.concurrent.lock.LockManager;
 
 import java.util.*;
 
@@ -42,10 +43,22 @@ public class SimpleOAuth2RequestService implements OAuth2RequestService {
 
     private Map<String, Map<Class, List<OAuth2Listener>>> listenerStore = new HashMap<>();
 
-    public SimpleOAuth2RequestService(OAuth2ServerConfigRepository oAuth2ServerConfigService, OAuth2UserTokenRepository oAuth2UserTokenService, OAuth2RequestBuilderFactory oAuth2RequestBuilderFactory) {
+    private LockManager lockManager;
+
+
+    public SimpleOAuth2RequestService(
+            OAuth2ServerConfigRepository oAuth2ServerConfigService
+            , OAuth2UserTokenRepository oAuth2UserTokenService
+            , OAuth2RequestBuilderFactory oAuth2RequestBuilderFactory
+            , LockManager lockManager) {
         this.oAuth2ServerConfigService = oAuth2ServerConfigService;
         this.oAuth2UserTokenService = oAuth2UserTokenService;
         this.oAuth2RequestBuilderFactory = oAuth2RequestBuilderFactory;
+        this.lockManager = lockManager;
+    }
+
+    public void setLockManager(LockManager lockManager) {
+        this.lockManager = lockManager;
     }
 
     @Override
@@ -54,7 +67,8 @@ public class SimpleOAuth2RequestService implements OAuth2RequestService {
         if (null == configEntity || !Byte.valueOf((byte) 1).equals(configEntity.getStatus())) {
             throw new NotFoundException("server not found!");
         }
-        return new SimpleOAuth2SessionBuilder(oAuth2UserTokenService, configEntity, oAuth2RequestBuilderFactory);
+        return new SimpleOAuth2SessionBuilder(oAuth2UserTokenService, configEntity, oAuth2RequestBuilderFactory,
+                lockManager.getReadWriteLock("oauth2-server-lock." + serverId));
     }
 
     @Override

+ 40 - 16
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/SimpleOAuth2SessionBuilder.java

@@ -43,14 +43,17 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
 
     private OAuth2RequestBuilderFactory requestBuilderFactory;
 
-    private ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
+    private ReadWriteLock readWriteLock;//.= new ReentrantReadWriteLock();
+
 
     public SimpleOAuth2SessionBuilder(OAuth2UserTokenRepository oAuth2UserTokenRepository,
                                       OAuth2ServerConfig oAuth2ServerConfig,
-                                      OAuth2RequestBuilderFactory requestBuilderFactory) {
+                                      OAuth2RequestBuilderFactory requestBuilderFactory,
+                                      ReadWriteLock readWriteLock) {
         this.oAuth2UserTokenRepository = oAuth2UserTokenRepository;
         this.serverConfig = oAuth2ServerConfig;
         this.requestBuilderFactory = requestBuilderFactory;
+        this.readWriteLock = readWriteLock;
     }
 
     protected String getRealUrl(String url) {
@@ -65,20 +68,15 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
 
 
     protected AccessTokenInfo getClientCredentialsToken() {
-        readWriteLock.readLock().lock();
-        try {
-            List<AccessTokenInfo> list = oAuth2UserTokenRepository
-                    .findByServerIdAndGrantType(serverConfig.getId(), GrantType.client_credentials);
-            return list.isEmpty() ? null : list.get(0);
-        } finally {
-            readWriteLock.readLock().unlock();
-        }
+        List<AccessTokenInfo> list = oAuth2UserTokenRepository
+                .findByServerIdAndGrantType(serverConfig.getId(), GrantType.client_credentials);
+        return list.isEmpty() ? null : list.get(0);
     }
 
     protected Consumer<AccessTokenInfo> createOnTokenChanged(Supplier<AccessTokenInfo> tokenGetter, String grantType) {
         return token -> {
-            AccessTokenInfo tokenInfo = tokenGetter.get();
             readWriteLock.writeLock().lock();
+            AccessTokenInfo tokenInfo = tokenGetter.get();
             try {
                 if (tokenInfo != null) {
                     token.setId(tokenInfo.getId());
@@ -111,12 +109,39 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
 
     @Override
     public OAuth2Session byClientCredentials() {
-        AccessTokenInfo tokenInfo = getClientCredentialsToken();
         DefaultOAuth2Session session;
-        if (null != tokenInfo) {
-            session = new CachedOAuth2Session(tokenInfo);
+        Supplier<AccessTokenInfo> tokenGetter = () -> {
+            readWriteLock.readLock().lock();
+            try {
+                return getClientCredentialsToken();
+            } finally {
+                readWriteLock.readLock().unlock();
+            }
+        };
+        AccessTokenInfo info = tokenGetter.get();
+        if (null != info) {
+            session = new CachedOAuth2Session(info);
         } else {
-            session = new DefaultOAuth2Session();
+            readWriteLock.writeLock().lock();
+            try {
+                info = getClientCredentialsToken();
+                if (null == info) {
+                    session = new DefaultOAuth2Session();
+                    session.setServerConfig(serverConfig);
+                    session.setRequestBuilderFactory(requestBuilderFactory);
+                    session.onTokenChanged(onClientCredentialsTokenChanged);
+                    session.init();
+                    session.param(OAuth2Constants.grant_type, GrantType.client_credentials);
+                    info = session.requestAccessToken();
+                    info.setGrantType(GrantType.client_credentials);
+                    info.setCreateTime(System.currentTimeMillis());
+                    info.setServerId(serverConfig.getId());
+                    oAuth2UserTokenRepository.insert(info);
+                }
+            } finally {
+                readWriteLock.writeLock().unlock();
+            }
+            session = new CachedOAuth2Session(info);
         }
         session.setServerConfig(serverConfig);
         session.setRequestBuilderFactory(requestBuilderFactory);
@@ -151,5 +176,4 @@ public class SimpleOAuth2SessionBuilder implements OAuth2SessionBuilder {
         return session;
     }
 
-
 }

+ 22 - 35
hsweb-authorization/hsweb-authorization-oauth2/hsweb-authorization-oauth2-client/src/main/java/org/hswebframework/web/authorization/oauth2/client/simple/session/DefaultOAuth2Session.java

@@ -50,10 +50,6 @@ public class DefaultOAuth2Session implements OAuth2Session {
 
     private Consumer<AccessTokenInfo> onTokenChange;
 
-    private boolean authorized = false;
-
-    private ReadWriteLock readWriteLock = new ReentrantReadWriteLock();
-
     public void setRequestBuilderFactory(OAuth2RequestBuilderFactory requestBuilderFactory) {
         this.requestBuilderFactory = requestBuilderFactory;
     }
@@ -108,21 +104,7 @@ public class DefaultOAuth2Session implements OAuth2Session {
 
     @Override
     public OAuth2Session authorize() {
-        readWriteLock.writeLock().lock();
-        if (authorized) {
-            return this;
-        }
-        try {
-            AccessTokenInfo accessTokenInfo = accessTokenRequest
-                    .param(OAuth2Constants.scope, scope)
-                    .post().onError(OAuth2Response.throwOnError)
-                    .as(AccessTokenInfo.class);
-            accessTokenInfo.setCreateTime(System.currentTimeMillis());
-            setAccessTokenInfo(accessTokenInfo);
-            authorized = true;
-        } finally {
-            readWriteLock.writeLock().unlock();
-        }
+        setAccessTokenInfo(requestAccessToken());
         return this;
     }
 
@@ -147,26 +129,31 @@ public class DefaultOAuth2Session implements OAuth2Session {
         return this;
     }
 
+    @Override
+    public AccessTokenInfo requestAccessToken() {
+        AccessTokenInfo accessTokenInfo = accessTokenRequest
+                .param(OAuth2Constants.scope, scope)
+                .post().onError(OAuth2Response.throwOnError)
+                .as(AccessTokenInfo.class);
+        accessTokenInfo.setCreateTime(System.currentTimeMillis());
+        return accessTokenInfo;
+    }
+
     protected void refreshToken() {
         if (accessTokenInfo == null) {
             return;
         }
-        readWriteLock.writeLock().lock();
-        try {
-            OAuth2Request request = createRequest(getRealUrl(serverConfig.getAccessTokenUrl()));
-            applyBasicAuthParam(request);
-            AccessTokenInfo tokenInfo = request
-                    .param(OAuth2Constants.scope, scope)
-                    .param(OAuth2Constants.grant_type, GrantType.refresh_token)
-                    .param(GrantType.refresh_token, accessTokenInfo.getRefreshToken())
-                    .post().onError(OAuth2Response.throwOnError)
-                    .as(AccessTokenInfo.class);
-            tokenInfo.setCreateTime(accessTokenInfo.getCreateTime());
-            tokenInfo.setUpdateTime(System.currentTimeMillis());
-            setAccessTokenInfo(tokenInfo);
-        } finally {
-            readWriteLock.writeLock().unlock();
-        }
+        OAuth2Request request = createRequest(getRealUrl(serverConfig.getAccessTokenUrl()));
+        applyBasicAuthParam(request);
+        AccessTokenInfo tokenInfo = request
+                .param(OAuth2Constants.scope, scope)
+                .param(OAuth2Constants.grant_type, GrantType.refresh_token)
+                .param(GrantType.refresh_token, accessTokenInfo.getRefreshToken())
+                .post().onError(OAuth2Response.throwOnError)
+                .as(AccessTokenInfo.class);
+        tokenInfo.setCreateTime(accessTokenInfo.getCreateTime());
+        tokenInfo.setUpdateTime(System.currentTimeMillis());
+        setAccessTokenInfo(tokenInfo);
     }