@@ -31,6 +31,8 @@ import (
3131 "github.com/ory/kratos/identity"
3232 "github.com/ory/kratos/pkg"
3333 "github.com/ory/kratos/pkg/testhelpers"
34+ "github.com/ory/kratos/selfservice/flow"
35+ "github.com/ory/kratos/selfservice/flow/registration"
3436 . "github.com/ory/kratos/session"
3537 "github.com/ory/kratos/x"
3638 "github.com/ory/kratos/x/nosurfx"
@@ -1096,6 +1098,104 @@ func TestHandlerRefreshSessionBySessionID(t *testing.T) {
10961098 })
10971099}
10981100
1101+ func TestExchangeCode (t * testing.T ) {
1102+ t .Parallel ()
1103+
1104+ conf , reg := pkg .NewFastRegistryWithMocks (t ,
1105+ configx .WithValues (testhelpers .DefaultIdentitySchemaConfig ("file://./stub/identity.schema.json" )),
1106+ )
1107+ ts , _ , _ , _ := testhelpers .NewKratosServerWithCSRFAndRouters (t , reg )
1108+ ctx := context .Background ()
1109+
1110+ newRegistrationFlow := func (t * testing.T ) * registration.Flow {
1111+ t .Helper ()
1112+ req := & http.Request {URL : urlx .ParseOrPanic ("/" )}
1113+ f , err := registration .NewFlow (conf , time .Minute , "csrf_token" , req , flow .TypeAPI )
1114+ require .NoError (t , err )
1115+ require .NoError (t , reg .RegistrationFlowPersister ().CreateRegistrationFlow (ctx , f ))
1116+ return f
1117+ }
1118+
1119+ exchangeURL := func (initCode , returnToCode string ) string {
1120+ return fmt .Sprintf ("%s/sessions/token-exchange?init_code=%s&return_to_code=%s" , ts .URL , initCode , returnToCode )
1121+ }
1122+
1123+ t .Run ("case=returns 400 when codes are missing" , func (t * testing.T ) {
1124+ t .Parallel ()
1125+ res , err := ts .Client ().Get (ts .URL + "/sessions/token-exchange" )
1126+ require .NoError (t , err )
1127+ defer func () { _ = res .Body .Close () }()
1128+ assert .Equal (t , http .StatusBadRequest , res .StatusCode )
1129+ })
1130+
1131+ t .Run ("case=returns 404 for invalid codes" , func (t * testing.T ) {
1132+ t .Parallel ()
1133+ res , err := ts .Client ().Get (exchangeURL ("invalid_init" , "invalid_return" ))
1134+ require .NoError (t , err )
1135+ defer func () { _ = res .Body .Close () }()
1136+ assert .Equal (t , http .StatusNotFound , res .StatusCode )
1137+ })
1138+
1139+ t .Run ("case=returns 422 with flow when exchanger exists but has no session" , func (t * testing.T ) {
1140+ t .Parallel ()
1141+ f := newRegistrationFlow (t )
1142+
1143+ e , err := reg .SessionTokenExchangePersister ().CreateSessionTokenExchanger (ctx , f .ID )
1144+ require .NoError (t , err )
1145+
1146+ res , err := ts .Client ().Get (exchangeURL (e .InitCode , e .ReturnToCode ))
1147+ require .NoError (t , err )
1148+ defer func () { _ = res .Body .Close () }()
1149+
1150+ assert .Equal (t , http .StatusUnprocessableEntity , res .StatusCode )
1151+
1152+ body , err := io .ReadAll (res .Body )
1153+ require .NoError (t , err )
1154+ assert .Equal (t , f .ID .String (), gjson .GetBytes (body , "id" ).String ())
1155+ })
1156+
1157+ t .Run ("case=returns 404 when exchanger exists but flow was deleted" , func (t * testing.T ) {
1158+ t .Parallel ()
1159+ // Create an exchanger with a flow ID that has no corresponding persisted flow.
1160+ orphanFlowID := uuid .Must (uuid .NewV4 ())
1161+ e , err := reg .SessionTokenExchangePersister ().CreateSessionTokenExchanger (ctx , orphanFlowID )
1162+ require .NoError (t , err )
1163+
1164+ res , err := ts .Client ().Get (exchangeURL (e .InitCode , e .ReturnToCode ))
1165+ require .NoError (t , err )
1166+ defer func () { _ = res .Body .Close () }()
1167+
1168+ assert .Equal (t , http .StatusNotFound , res .StatusCode )
1169+ })
1170+
1171+ t .Run ("case=returns 200 with session when exchanger has a session" , func (t * testing.T ) {
1172+ t .Parallel ()
1173+ f := newRegistrationFlow (t )
1174+
1175+ e , err := reg .SessionTokenExchangePersister ().CreateSessionTokenExchanger (ctx , f .ID )
1176+ require .NoError (t , err )
1177+
1178+ i := identity .NewIdentity ("" )
1179+ require .NoError (t , reg .IdentityManager ().Create (ctx , i ))
1180+ req := & http.Request {URL : urlx .ParseOrPanic ("/" )}
1181+ sess , err := testhelpers .NewActiveSession (req , reg , i , time .Now (), identity .CredentialsTypePassword , identity .AuthenticatorAssuranceLevel1 )
1182+ require .NoError (t , err )
1183+ require .NoError (t , reg .SessionPersister ().UpsertSession (ctx , sess ))
1184+ require .NoError (t , reg .SessionTokenExchangePersister ().UpdateSessionOnExchanger (ctx , f .ID , sess .ID ))
1185+
1186+ res , err := ts .Client ().Get (exchangeURL (e .InitCode , e .ReturnToCode ))
1187+ require .NoError (t , err )
1188+ defer func () { _ = res .Body .Close () }()
1189+
1190+ assert .Equal (t , http .StatusOK , res .StatusCode )
1191+
1192+ body , err := io .ReadAll (res .Body )
1193+ require .NoError (t , err )
1194+ assert .NotEmpty (t , gjson .GetBytes (body , "session_token" ).String ())
1195+ assert .Equal (t , sess .ID .String (), gjson .GetBytes (body , "session.id" ).String ())
1196+ })
1197+ }
1198+
10991199type byCreatedAt []Session
11001200
11011201func (s byCreatedAt ) Len () int { return len (s ) }
0 commit comments