Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/docs/openapi3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,11 @@ var ErrURINotSupported = errors.New("unsupported URI")
ErrURINotSupported indicates the ReadFromURIFunc does not know how to handle
a given URI.

var IncludeOrigin = false
IncludeOrigin specifies whether to include the origin of the OpenAPI
elements Set this to true before loading a spec to include the origin of the
OpenAPI elements Note it is global and affects all loaders


FUNCTIONS

Expand Down Expand Up @@ -790,9 +795,6 @@ type Loader struct {
// IsExternalRefsAllowed enables visiting other files
IsExternalRefsAllowed bool

// IncludeOrigin specifies whether to include the origin of the OpenAPI elements
IncludeOrigin bool

// ReadFromURIFunc allows overriding the any file/URL reading func
ReadFromURIFunc ReadFromURIFunc

Expand Down
16 changes: 9 additions & 7 deletions openapi3/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ import (
"strings"
)

// IncludeOrigin specifies whether to include the origin of the OpenAPI elements
// Set this to true before loading a spec to include the origin of the OpenAPI elements
// Note it is global and affects all loaders
var IncludeOrigin = false

func foundUnresolvedRef(ref string) error {
return fmt.Errorf("found unresolved ref: %q", ref)
}
Expand All @@ -28,9 +33,6 @@ type Loader struct {
// IsExternalRefsAllowed enables visiting other files
IsExternalRefsAllowed bool

// IncludeOrigin specifies whether to include the origin of the OpenAPI elements
IncludeOrigin bool

// ReadFromURIFunc allows overriding the any file/URL reading func
ReadFromURIFunc ReadFromURIFunc

Expand Down Expand Up @@ -106,7 +108,7 @@ func (loader *Loader) loadSingleElementFromURI(ref string, rootPath *url.URL, el
if err != nil {
return nil, err
}
if err := unmarshal(data, element, loader.IncludeOrigin); err != nil {
if err := unmarshal(data, element, IncludeOrigin); err != nil {
return nil, err
}

Expand Down Expand Up @@ -142,7 +144,7 @@ func (loader *Loader) LoadFromIoReader(reader io.Reader) (*T, error) {
func (loader *Loader) LoadFromData(data []byte) (*T, error) {
loader.resetVisitedPathItemRefs()
doc := &T{}
if err := unmarshal(data, doc, loader.IncludeOrigin); err != nil {
if err := unmarshal(data, doc, IncludeOrigin); err != nil {
return nil, err
}
if err := loader.ResolveRefsIn(doc, nil); err != nil {
Expand Down Expand Up @@ -171,7 +173,7 @@ func (loader *Loader) loadFromDataWithPathInternal(data []byte, location *url.UR
doc := &T{}
loader.visitedDocuments[uri] = doc

if err := unmarshal(data, doc, loader.IncludeOrigin); err != nil {
if err := unmarshal(data, doc, IncludeOrigin); err != nil {
return nil, err
}

Expand Down Expand Up @@ -425,7 +427,7 @@ func (loader *Loader) resolveComponent(doc *T, ref string, path *url.URL, resolv
if err2 != nil {
return nil, nil, err
}
if err2 = unmarshal(data, &cursor, loader.IncludeOrigin); err2 != nil {
if err2 = unmarshal(data, &cursor, IncludeOrigin); err2 != nil {
return nil, nil, err
}
if cursor, err2 = drill(cursor); err2 != nil || cursor == nil {
Expand Down
88 changes: 78 additions & 10 deletions openapi3/origin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@ import (
"github.com/stretchr/testify/require"
)

func unsetIncludeOrigin() {
IncludeOrigin = false
}

func TestOrigin_Info(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/simple.yaml")
Expand Down Expand Up @@ -42,7 +49,10 @@ func TestOrigin_Info(t *testing.T) {
func TestOrigin_Paths(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/simple.yaml")
Expand Down Expand Up @@ -78,7 +88,10 @@ func TestOrigin_Paths(t *testing.T) {
func TestOrigin_RequestBody(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/request_body.yaml")
Expand All @@ -105,7 +118,10 @@ func TestOrigin_RequestBody(t *testing.T) {
func TestOrigin_Responses(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/simple.yaml")
Expand Down Expand Up @@ -140,7 +156,10 @@ func TestOrigin_Responses(t *testing.T) {
func TestOrigin_Parameters(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/parameters.yaml")
Expand Down Expand Up @@ -173,7 +192,10 @@ func TestOrigin_Parameters(t *testing.T) {
func TestOrigin_SchemaInAdditionalProperties(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/additional_properties.yaml")
Expand Down Expand Up @@ -201,7 +223,10 @@ func TestOrigin_SchemaInAdditionalProperties(t *testing.T) {
func TestOrigin_ExternalDocs(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/external_docs.yaml")
Expand Down Expand Up @@ -235,7 +260,10 @@ func TestOrigin_ExternalDocs(t *testing.T) {
func TestOrigin_Security(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/security.yaml")
Expand Down Expand Up @@ -283,7 +311,10 @@ func TestOrigin_Security(t *testing.T) {
func TestOrigin_Example(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/example.yaml")
Expand Down Expand Up @@ -319,7 +350,10 @@ func TestOrigin_Example(t *testing.T) {
func TestOrigin_XML(t *testing.T) {
loader := NewLoader()
loader.IsExternalRefsAllowed = true
loader.IncludeOrigin = true

IncludeOrigin = true
defer unsetIncludeOrigin()

loader.Context = context.Background()

doc, err := loader.LoadFromFile("testdata/origin/xml.yaml")
Expand Down Expand Up @@ -348,3 +382,37 @@ func TestOrigin_XML(t *testing.T) {
},
base.Origin.Fields["prefix"])
}

// TestOrigin_OriginExistsInProperties verifies that loading fails when a specification
// contains a property named "__origin__", highlighting a limitation in the current implementation.
func TestOrigin_OriginExistsInProperties(t *testing.T) {
var data = `
paths:
/foo:
get:
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/Foo"
components:
schemas:
Foo:
type: object
properties:
__origin__:
type: string
`

loader := NewLoader()

IncludeOrigin = true
defer unsetIncludeOrigin()

_, err := loader.LoadFromData([]byte(data))
require.Error(t, err)
require.Equal(t, `failed to unmarshal data: json error: invalid character 'p' looking for beginning of value, yaml error: error converting YAML to JSON: yaml: unmarshal errors:
line 0: mapping key "__origin__" already defined at line 17`, err.Error())
}
20 changes: 16 additions & 4 deletions openapi3/stringmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,10 @@ func unmarshalStringMapP[V any](data []byte) (map[string]*V, *Origin, error) {
return nil, nil, err
}

origin, err := deepCast[Origin](m[originKey])
origin, err := popOrigin(m, originKey)
if err != nil {
return nil, nil, err
}
delete(m, originKey)

result := make(map[string]*V, len(m))
for k, v := range m {
Expand All @@ -43,11 +42,10 @@ func unmarshalStringMap[V any](data []byte) (map[string]V, *Origin, error) {
return nil, nil, err
}

origin, err := deepCast[Origin](m[originKey])
origin, err := popOrigin(m, originKey)
if err != nil {
return nil, nil, err
}
delete(m, originKey)

result := make(map[string]V, len(m))
for k, v := range m {
Expand All @@ -74,3 +72,17 @@ func deepCast[V any](value any) (*V, error) {
}
return &result, nil
}

// popOrigin removes the origin from the map and returns it.
func popOrigin(m map[string]any, key string) (*Origin, error) {
if !IncludeOrigin {
return nil, nil
}

origin, err := deepCast[Origin](m[key])
if err != nil {
return nil, err
}
delete(m, key)
return origin, nil
}
Loading