Browse Source

优化clientId和secure获取逻辑

zhou-hao 3 years ago
parent
commit
2c8f12dffa

+ 26 - 8
hsweb-authorization/hsweb-authorization-oauth2/src/main/java/org/hswebframework/web/oauth2/server/web/OAuth2AuthorizeController.java

@@ -6,6 +6,7 @@ import io.swagger.v3.oas.annotations.media.Schema;
 import io.swagger.v3.oas.annotations.tags.Tag;
 import lombok.AllArgsConstructor;
 import lombok.SneakyThrows;
+import org.apache.commons.codec.binary.Base64;
 import org.hswebframework.web.authorization.Authentication;
 import org.hswebframework.web.authorization.annotation.Authorize;
 import org.hswebframework.web.authorization.exception.UnAuthorizedException;
@@ -19,14 +20,18 @@ import org.hswebframework.web.oauth2.server.code.AuthorizationCodeRequest;
 import org.hswebframework.web.oauth2.server.code.AuthorizationCodeTokenRequest;
 import org.hswebframework.web.oauth2.server.credential.ClientCredentialRequest;
 import org.hswebframework.web.oauth2.server.refresh.RefreshTokenRequest;
+import org.springframework.http.HttpHeaders;
 import org.springframework.http.MediaType;
 import org.springframework.http.ResponseEntity;
 import org.springframework.util.MultiValueMap;
 import org.springframework.web.bind.annotation.*;
 import org.springframework.web.server.ServerWebExchange;
 import reactor.core.publisher.Mono;
+import reactor.util.function.Tuple2;
+import reactor.util.function.Tuples;
 
 import java.net.URLEncoder;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Map;
 import java.util.Optional;
@@ -84,10 +89,10 @@ public class OAuth2AuthorizeController {
             @RequestParam("grant_type") GrantType grantType,
             ServerWebExchange exchange) {
         Map<String, String> params = exchange.getRequest().getQueryParams().toSingleValueMap();
-
+        Tuple2<String,String> clientIdAndSecret = getClientIdAndClientSecret(params,exchange);
         return this
-                .getOAuth2Client(params.get("client_id"))
-                .doOnNext(client -> client.validateSecret(params.get("client_secret")))
+                .getOAuth2Client(clientIdAndSecret.getT1())
+                .doOnNext(client -> client.validateSecret(clientIdAndSecret.getT2()))
                 .flatMap(client -> grantType.requestToken(oAuth2GrantService, client, new HashMap<>(params)))
                 .map(ResponseEntity::ok);
     }
@@ -106,15 +111,28 @@ public class OAuth2AuthorizeController {
                 .getFormData()
                 .map(MultiValueMap::toSingleValueMap)
                 .flatMap(params -> {
+                    Tuple2<String,String> clientIdAndSecret = getClientIdAndClientSecret(params,exchange);
                     GrantType grantType = GrantType.of(params.get("grant_type"));
                     return this
-                            .getOAuth2Client(params.get("client_id"))
-                            .doOnNext(client -> client.validateSecret(params.get("client_secret")))
+                            .getOAuth2Client(clientIdAndSecret.getT1())
+                            .doOnNext(client -> client.validateSecret(clientIdAndSecret.getT2()))
                             .flatMap(client -> grantType.requestToken(oAuth2GrantService, client, new HashMap<>(params)))
                             .map(ResponseEntity::ok);
                 });
     }
 
+    private Tuple2<String, String> getClientIdAndClientSecret(Map<String, String> params, ServerWebExchange exchange) {
+        String authorization = exchange.getRequest().getHeaders().getFirst(HttpHeaders.AUTHORIZATION);
+        if (authorization != null && authorization.startsWith("Basic ")) {
+            String[] arr = new String(Base64.decodeBase64(authorization.substring(5))).split(":");
+            if (arr.length >= 2) {
+                return Tuples.of(arr[0], arr[1]);
+            }
+            return Tuples.of(arr[0], arr[0]);
+        }
+        return Tuples.of(params.getOrDefault("client_id",""),params.getOrDefault("client_secret",""));
+    }
+
     public enum GrantType {
         authorization_code {
             @Override
@@ -132,7 +150,7 @@ public class OAuth2AuthorizeController {
                         .requestToken(new ClientCredentialRequest(client, param));
             }
         },
-        refresh_token{
+        refresh_token {
             @Override
             Mono<AccessToken> requestToken(OAuth2GrantService service, OAuth2Client client, Map<String, String> param) {
                 return service
@@ -143,10 +161,10 @@ public class OAuth2AuthorizeController {
 
         abstract Mono<AccessToken> requestToken(OAuth2GrantService service, OAuth2Client client, Map<String, String> param);
 
-        static GrantType of(String name){
+        static GrantType of(String name) {
             try {
                 return GrantType.valueOf(name);
-            }catch (Throwable e){
+            } catch (Throwable e) {
                 throw new OAuth2Exception(ErrorType.UNSUPPORTED_GRANT_TYPE);
             }
         }