Skip to content

Commit 7e55bf7

Browse files
authored
Add "warnings" field to FGA query and check response (#426)
* Add "warnings" field to FGA query and check response * Remove print statement * Change GetWarning to GetMessage * Refactor warnings container to move unmarshing logic into * Fix breaking change
1 parent a993080 commit 7e55bf7

3 files changed

Lines changed: 322 additions & 5 deletions

File tree

pkg/fga/client.go

Lines changed: 115 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -379,10 +379,96 @@ type CheckBatchOpts struct {
379379
WarrantToken string `json:"-"`
380380
}
381381

382+
type Warning interface {
383+
Warning() string
384+
}
385+
386+
type BaseWarning struct {
387+
Code string `json:"code"`
388+
Message string `json:"message"`
389+
}
390+
391+
func (b BaseWarning) Warning() string { return fmt.Sprintf("%s: %s", b.Code, b.Message) }
392+
393+
type MissingContextKeysWarning struct {
394+
BaseWarning
395+
Keys []string `json:"keys"`
396+
}
397+
398+
func (m MissingContextKeysWarning) Warning() string {
399+
return fmt.Sprintf("%s: %s [%s]", m.Code, m.Message, strings.Join(m.Keys, ", "))
400+
}
401+
402+
type ConvertSchemaWarning struct {
403+
BaseWarning
404+
}
405+
406+
var warningRegistry = map[string]func() Warning{
407+
"missing_context_keys": func() Warning { return &MissingContextKeysWarning{} },
408+
"validation_warning": func() Warning { return &ConvertSchemaWarning{} },
409+
}
410+
411+
func unmarshalWarnings(raw json.RawMessage) ([]Warning, error) {
412+
var rawList []json.RawMessage
413+
if err := json.Unmarshal(raw, &rawList); err != nil {
414+
return nil, fmt.Errorf("unmarshaling warnings list: %s", err.Error())
415+
}
416+
417+
var warnings []Warning
418+
for _, rawItem := range rawList {
419+
var rawWarning struct {
420+
Code string `json:"code"`
421+
}
422+
if err := json.Unmarshal(rawItem, &rawWarning); err != nil {
423+
return nil, fmt.Errorf("extracting warning code: %s", err.Error())
424+
}
425+
426+
var warning Warning
427+
if constructor, ok := warningRegistry[rawWarning.Code]; ok {
428+
warning = constructor()
429+
} else {
430+
warning = &BaseWarning{}
431+
}
432+
433+
if err := json.Unmarshal(rawItem, warning); err != nil {
434+
return nil, fmt.Errorf("decoding warning: %s", err.Error())
435+
}
436+
437+
warnings = append(warnings, warning)
438+
}
439+
440+
return warnings, nil
441+
}
442+
443+
func (checkResponse *CheckResponse) UnmarshalJSON(data []byte) error {
444+
type Alias CheckResponse
445+
var raw struct {
446+
Alias
447+
Warnings json.RawMessage `json:"warnings,omitempty"`
448+
}
449+
450+
if err := json.Unmarshal(data, &raw); err != nil {
451+
return err
452+
}
453+
454+
*checkResponse = CheckResponse(raw.Alias)
455+
456+
if len(raw.Warnings) > 0 {
457+
warnings, err := unmarshalWarnings(raw.Warnings)
458+
if err != nil {
459+
return err
460+
}
461+
checkResponse.Warnings = warnings
462+
}
463+
464+
return nil
465+
}
466+
382467
type CheckResponse struct {
383468
Result string `json:"result"`
384469
IsImplicit bool `json:"is_implicit"`
385470
DebugInfo DebugInfo `json:"debug_info,omitempty"`
471+
Warnings []Warning `json:"warnings,omitempty"`
386472
}
387473

388474
func (checkResponse CheckResponse) Authorized() bool {
@@ -403,6 +489,7 @@ type DecisionTreeNode struct {
403489
}
404490

405491
// Query
492+
406493
type QueryOpts struct {
407494
// Query to be executed.
408495
Query string `url:"q"`
@@ -452,19 +539,42 @@ type QueryResponse struct {
452539

453540
// Cursor pagination options.
454541
ListMetadata common.ListMetadata `json:"list_metadata"`
542+
543+
// Warnings generated from query issues.
544+
Warnings []Warning `json:"warnings,omitempty"`
545+
}
546+
547+
func (queryResponse *QueryResponse) UnmarshalJSON(data []byte) error {
548+
type Alias QueryResponse
549+
var raw struct {
550+
Alias
551+
Warnings json.RawMessage `json:"warnings,omitempty"`
552+
}
553+
554+
if err := json.Unmarshal(data, &raw); err != nil {
555+
return err
556+
}
557+
558+
*queryResponse = QueryResponse(raw.Alias)
559+
560+
if len(raw.Warnings) > 0 {
561+
warnings, err := unmarshalWarnings(raw.Warnings)
562+
if err != nil {
563+
return err
564+
}
565+
queryResponse.Warnings = warnings
566+
}
567+
568+
return nil
455569
}
456570

457571
// Schema
572+
458573
type ConvertSchemaToResourceTypesOpts struct {
459574
// The schema to convert to resource types.
460575
Schema string
461576
}
462577

463-
type ConvertSchemaWarning struct {
464-
// The warning message.
465-
Message string `json:"message"`
466-
}
467-
468578
type ConvertResourceTypesToSchemaOpts struct {
469579
// The version of the transpiler to use.
470580
Version string `json:"version"`

pkg/fga/client_test.go

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1098,6 +1098,57 @@ func checkTestHandler(w http.ResponseWriter, r *http.Request) {
10981098
w.Write(body)
10991099
}
11001100

1101+
func checkTestHandlerWarnings(w http.ResponseWriter, r *http.Request) {
1102+
auth := r.Header.Get("Authorization")
1103+
if auth != "Bearer test" {
1104+
http.Error(w, "bad auth", http.StatusUnauthorized)
1105+
return
1106+
}
1107+
1108+
if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") {
1109+
w.WriteHeader(http.StatusBadRequest)
1110+
return
1111+
}
1112+
1113+
// Create concrete warnings and wrap them
1114+
warnings := []Warning{
1115+
&MissingContextKeysWarning{
1116+
BaseWarning: BaseWarning{
1117+
Code: "missing_context_keys",
1118+
Message: "Some context keys were not provided.",
1119+
},
1120+
Keys: []string{"user_id", "org_id"},
1121+
},
1122+
1123+
&BaseWarning{
1124+
Code: "unknown",
1125+
Message: "Unknown warning occurred.",
1126+
},
1127+
1128+
&ConvertSchemaWarning{
1129+
BaseWarning: BaseWarning{
1130+
Code: "validation_warning",
1131+
Message: "Schema validation produced a warning.",
1132+
},
1133+
},
1134+
}
1135+
1136+
body, err := json.Marshal(
1137+
CheckResponse{
1138+
Result: CheckResultAuthorized,
1139+
IsImplicit: false,
1140+
Warnings: warnings,
1141+
})
1142+
1143+
if err != nil {
1144+
w.WriteHeader(http.StatusInternalServerError)
1145+
return
1146+
}
1147+
1148+
w.WriteHeader(http.StatusOK)
1149+
w.Write(body)
1150+
}
1151+
11011152
func TestCheckBatch(t *testing.T) {
11021153
tests := []struct {
11031154
scenario string
@@ -1317,6 +1368,75 @@ func queryTestHandler(w http.ResponseWriter, r *http.Request) {
13171368
w.Write(body)
13181369
}
13191370

1371+
func queryTestHandlerWarnings(w http.ResponseWriter, r *http.Request) {
1372+
auth := r.Header.Get("Authorization")
1373+
if auth != "Bearer test" {
1374+
http.Error(w, "bad auth", http.StatusUnauthorized)
1375+
return
1376+
}
1377+
1378+
if userAgent := r.Header.Get("User-Agent"); !strings.Contains(userAgent, "workos-go/") {
1379+
w.WriteHeader(http.StatusBadRequest)
1380+
return
1381+
}
1382+
1383+
// Create concrete warnings and wrap them
1384+
warnings := []Warning{
1385+
&MissingContextKeysWarning{
1386+
BaseWarning: BaseWarning{
1387+
Code: "missing_context_keys",
1388+
Message: "Some context keys were not provided.",
1389+
},
1390+
Keys: []string{"user_id", "org_id"},
1391+
},
1392+
&BaseWarning{
1393+
Code: "unknown",
1394+
Message: "Unknown warning occurred.",
1395+
},
1396+
&ConvertSchemaWarning{
1397+
BaseWarning: BaseWarning{
1398+
Code: "validation_warning",
1399+
Message: "Schema validation produced a warning.",
1400+
},
1401+
},
1402+
}
1403+
1404+
body, err := json.Marshal(struct {
1405+
QueryResponse
1406+
}{
1407+
QueryResponse: QueryResponse{
1408+
Data: []QueryResult{
1409+
{
1410+
ResourceType: "role",
1411+
ResourceId: "role_01SXW182",
1412+
Relation: "member",
1413+
Warrant: Warrant{
1414+
ResourceType: "role",
1415+
ResourceId: "role_01SXW182",
1416+
Relation: "member",
1417+
Subject: Subject{
1418+
ResourceType: "user",
1419+
ResourceId: "user_01SXW182",
1420+
},
1421+
},
1422+
},
1423+
},
1424+
ListMetadata: common.ListMetadata{
1425+
Before: "",
1426+
After: "",
1427+
},
1428+
Warnings: warnings,
1429+
},
1430+
})
1431+
if err != nil {
1432+
w.WriteHeader(http.StatusInternalServerError)
1433+
return
1434+
}
1435+
1436+
w.WriteHeader(http.StatusOK)
1437+
w.Write(body)
1438+
}
1439+
13201440
func convertSchemaToResourceTypesTestHandler(w http.ResponseWriter, r *http.Request) {
13211441
auth := r.Header.Get("Authorization")
13221442
if auth != "Bearer test" {

pkg/fga/fga_test.go

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"net/http"
66
"net/http/httptest"
7+
"sort"
78
"testing"
89

910
"github.com/stretchr/testify/require"
@@ -371,6 +372,54 @@ func TestFGACheck(t *testing.T) {
371372
require.True(t, checkResponse.Authorized())
372373
}
373374

375+
func TestFGACheckWithWarnings(t *testing.T) {
376+
server := httptest.NewServer(http.HandlerFunc(checkTestHandlerWarnings))
377+
defer server.Close()
378+
379+
DefaultClient = &Client{
380+
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
381+
Endpoint: server.URL,
382+
}
383+
SetAPIKey("test")
384+
385+
checkResponse, err := Check(context.Background(), CheckOpts{
386+
Checks: []WarrantCheck{
387+
{
388+
ResourceType: "report",
389+
ResourceId: "ljc_1029",
390+
Relation: "member",
391+
Subject: Subject{
392+
ResourceType: "user",
393+
ResourceId: "user_01SXW182",
394+
},
395+
},
396+
},
397+
})
398+
399+
require.NoError(t, err)
400+
require.Len(t, checkResponse.Warnings, 3)
401+
402+
sort.Slice(checkResponse.Warnings, func(i, j int) bool {
403+
return checkResponse.Warnings[i].Warning() < checkResponse.Warnings[j].Warning()
404+
})
405+
406+
first := checkResponse.Warnings[0]
407+
second := checkResponse.Warnings[1]
408+
third := checkResponse.Warnings[2]
409+
410+
mw, ok := first.(*MissingContextKeysWarning)
411+
require.True(t, ok)
412+
require.ElementsMatch(t, mw.Keys, []string{"user_id", "org_id"})
413+
414+
bw, ok := second.(*BaseWarning)
415+
require.True(t, ok)
416+
require.Equal(t, "unknown", bw.Code)
417+
418+
cw, ok := third.(*ConvertSchemaWarning)
419+
require.True(t, ok)
420+
require.Equal(t, "validation_warning", cw.Code)
421+
}
422+
374423
func TestFGACheckBatch(t *testing.T) {
375424
server := httptest.NewServer(http.HandlerFunc(checkBatchTestHandler))
376425
defer server.Close()
@@ -441,6 +490,44 @@ func TestFGAQuery(t *testing.T) {
441490
require.Equal(t, expectedResponse, queryResponse)
442491
}
443492

493+
func TestFGAQueryWithWarnings(t *testing.T) {
494+
server := httptest.NewServer(http.HandlerFunc(queryTestHandlerWarnings))
495+
defer server.Close()
496+
497+
DefaultClient = &Client{
498+
HTTPClient: &retryablehttp.HttpClient{Client: *server.Client()},
499+
Endpoint: server.URL,
500+
}
501+
SetAPIKey("test")
502+
503+
queryResponse, err := Query(context.Background(), QueryOpts{
504+
Query: "select role where user:user_01SXW182 is member",
505+
})
506+
507+
require.NoError(t, err)
508+
require.Len(t, queryResponse.Warnings, 3)
509+
510+
sort.Slice(queryResponse.Warnings, func(i, j int) bool {
511+
return queryResponse.Warnings[i].Warning() < queryResponse.Warnings[j].Warning()
512+
})
513+
514+
first := queryResponse.Warnings[0]
515+
second := queryResponse.Warnings[1]
516+
third := queryResponse.Warnings[2]
517+
518+
mw, ok := first.(*MissingContextKeysWarning)
519+
require.True(t, ok)
520+
require.ElementsMatch(t, mw.Keys, []string{"user_id", "org_id"})
521+
522+
bw, ok := second.(*BaseWarning)
523+
require.True(t, ok)
524+
require.Equal(t, "unknown", bw.Code)
525+
526+
cw, ok := third.(*ConvertSchemaWarning)
527+
require.True(t, ok)
528+
require.Equal(t, "validation_warning", cw.Code)
529+
}
530+
444531
func TestFGAConvertSchemaToResourceTypes(t *testing.T) {
445532
server := httptest.NewServer(http.HandlerFunc(convertSchemaToResourceTypesTestHandler))
446533
defer server.Close()

0 commit comments

Comments
 (0)