11use anyhow:: Result ;
22use serde:: Deserialize ;
3- use std:: net:: SocketAddr ;
3+ use std:: { collections :: HashMap , net:: SocketAddr } ;
44
55use axum:: {
6- extract:: { ConnectInfo , Form , Json , Query } ,
6+ extract:: { ConnectInfo , Form , Json , Multipart , Query } ,
77 http:: StatusCode ,
88 response:: IntoResponse ,
99} ;
@@ -14,50 +14,107 @@ use _functions::functions::system::oauth::{
1414} ;
1515
1616#[ derive( Deserialize ) ]
17- #[ serde( rename_all = "snake_case" , tag = "type" ) ]
18- pub enum LoginParams {
19- Password { username : String , password : String } ,
20- ClientCredentials { scope : String } ,
21- RefreshToken { refresh_token : String } ,
17+ pub struct LoginQuery {
18+ grant_type : Option < LoginQueryType > ,
19+ // ClientCredentials
20+ scope : Option < String > ,
21+ // RefreshToken
22+ refresh_token : Option < String > ,
23+ }
24+
25+ #[ derive( Deserialize ) ]
26+ #[ serde( rename_all = "snake_case" , untagged) ]
27+ pub enum LoginQueryType {
28+ // 这里留着只是用于标记,Password 这个项是故意忽略的,因为这个类型只能手动解析
29+ // Password,
30+ ClientCredentials ,
31+ RefreshToken ,
2232}
2333
2434#[ tracing:: instrument( skip_all) ]
2535pub async fn oauth (
2636 ConnectInfo ( native_ip) : ConnectInfo < SocketAddr > ,
2737 ExtractIP ( ip) : ExtractIP ,
2838 ExtractUserAgent ( user_agent) : ExtractUserAgent ,
29- Query ( query) : Query < Option < LoginParams > > ,
30- Form ( form) : Form < Option < LoginParams > > ,
39+ Query ( query) : Query < LoginQuery > ,
40+ mut form : Multipart ,
3141) -> Result < impl IntoResponse , ( StatusCode , String ) > {
3242 let ip = ip. unwrap_or ( native_ip) ;
3343
34- if let Some ( form ) = form {
35- match form {
36- LoginParams :: Password { username , password } => Ok ( Json (
37- oauth_password_login ( username , password )
38- . await
39- . map_err ( |err| ( StatusCode :: INTERNAL_SERVER_ERROR , err. to_string ( ) ) ) ? ,
44+ let mut form_fields = {
45+ let mut ret = HashMap :: new ( ) ;
46+ while let Some ( field ) = form . next_field ( ) . await . map_err ( |err| {
47+ (
48+ StatusCode :: BAD_REQUEST ,
49+ format ! ( "Failed to read form field: {}" , err) ,
4050 )
41- . into_response ( ) ) ,
42- _ => return Err ( ( StatusCode :: BAD_REQUEST , "Invalid grant type" . into ( ) ) ) ,
51+ } ) ? {
52+ let name = field
53+ . name ( )
54+ . ok_or ( ( StatusCode :: BAD_REQUEST , "Field name is required" . into ( ) ) ) ?
55+ . to_string ( ) ;
56+ let value = field. text ( ) . await . map_err ( |err| {
57+ (
58+ StatusCode :: BAD_REQUEST ,
59+ format ! ( "Failed to read form field {}: {}" , name, err) ,
60+ )
61+ } ) ?;
62+ ret. insert ( name, value) ;
63+ }
64+ ret
65+ } ;
66+ if !form_fields. is_empty ( ) {
67+ let grant_type = form_fields
68+ . remove ( "grant_type" )
69+ . ok_or ( ( StatusCode :: BAD_REQUEST , "Grant type is required" . into ( ) ) ) ?;
70+ if grant_type != "password" {
71+ return Err ( ( StatusCode :: BAD_REQUEST , "Invalid grant type" . into ( ) ) ) ;
4372 }
44- } else if let Some ( query) = query {
45- match query {
46- LoginParams :: ClientCredentials { scope } => Ok ( Json (
73+
74+ let username = form_fields
75+ . remove ( "username" )
76+ . ok_or ( ( StatusCode :: BAD_REQUEST , "Username is required" . into ( ) ) ) ?;
77+ let password = form_fields
78+ . remove ( "password" )
79+ . ok_or ( ( StatusCode :: BAD_REQUEST , "Password is required" . into ( ) ) ) ?;
80+ return Ok ( Json (
81+ oauth_password_login ( username, password)
82+ . await
83+ . map_err ( |err| ( StatusCode :: INTERNAL_SERVER_ERROR , err. to_string ( ) ) ) ?,
84+ )
85+ . into_response ( ) ) ;
86+ }
87+
88+ match query
89+ . grant_type
90+ . ok_or ( ( StatusCode :: BAD_REQUEST , "Grant type is required" . into ( ) ) ) ?
91+ {
92+ LoginQueryType :: ClientCredentials => {
93+ let scope = query. scope . ok_or_else ( || {
94+ (
95+ StatusCode :: BAD_REQUEST ,
96+ "Scope is required for client credentials" . into ( ) ,
97+ )
98+ } ) ?;
99+ return Ok ( Json (
47100 oauth_client_credentials ( scope)
48101 . await
49102 . map_err ( |err| ( StatusCode :: INTERNAL_SERVER_ERROR , err. to_string ( ) ) ) ?,
50103 )
51- . into_response ( ) ) ,
52- LoginParams :: RefreshToken { refresh_token } => {
53- oauth_refresh ( refresh_token)
54- . await
55- . map_err ( |err| ( StatusCode :: INTERNAL_SERVER_ERROR , err. to_string ( ) ) ) ?;
56- Ok ( ( ) . into_response ( ) )
57- }
58- _ => return Err ( ( StatusCode :: BAD_REQUEST , "Invalid grant type" . into ( ) ) ) ,
104+ . into_response ( ) ) ;
105+ }
106+ LoginQueryType :: RefreshToken => {
107+ let refresh_token = query. refresh_token . ok_or_else ( || {
108+ (
109+ StatusCode :: BAD_REQUEST ,
110+ "Refresh token is required for refresh token grant type" . into ( ) ,
111+ )
112+ } ) ?;
113+ oauth_refresh ( refresh_token)
114+ . await
115+ . map_err ( |err| ( StatusCode :: INTERNAL_SERVER_ERROR , err. to_string ( ) ) ) ?;
116+ return Ok ( StatusCode :: NO_CONTENT . into_response ( ) ) ;
59117 }
60- } else {
61- return Err ( ( StatusCode :: BAD_REQUEST , "No parameters provided" . into ( ) ) ) ;
118+ _ => return Err ( ( StatusCode :: BAD_REQUEST , "Invalid grant type" . into ( ) ) ) ,
62119 }
63120}
0 commit comments