Skip to content

Commit 332c9ae

Browse files
authored
fix: stateless tests (#1244)
* fix: stateless tests Signed-off-by: Alessio Pragliola <[email protected]> * fix: unit tests Signed-off-by: Alessio Pragliola <[email protected]> * fix: other stateless tests Signed-off-by: Alessio Pragliola <[email protected]> * fix: unit tests Signed-off-by: Alessio Pragliola <[email protected]> * fix: 409 status code missing errors Signed-off-by: Alessio Pragliola <[email protected]> * fix: nil pointer panics Signed-off-by: Alessio Pragliola <[email protected]> --------- Signed-off-by: Alessio Pragliola <[email protected]>
1 parent bcbcb59 commit 332c9ae

23 files changed

+185
-53
lines changed

internal/core/artifact.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"errors"
45
"fmt"
56
"strconv"
67

@@ -11,6 +12,7 @@ import (
1112
"github.com/kubeflow/model-registry/internal/mapper"
1213
"github.com/kubeflow/model-registry/pkg/api"
1314
"github.com/kubeflow/model-registry/pkg/openapi"
15+
"gorm.io/gorm"
1416
)
1517

1618
type ModelRegistryService struct {
@@ -66,7 +68,7 @@ func (b *ModelRegistryService) upsertArtifact(artifact *openapi.Artifact, modelV
6668
} else {
6769
convertedId, err := strconv.ParseInt(*modelVersionId, 10, 32)
6870
if err != nil {
69-
return nil, fmt.Errorf("invalid model version id: %w", err)
71+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
7072
}
7173

7274
convertedIdInt32 := int32(convertedId)
@@ -104,11 +106,15 @@ func (b *ModelRegistryService) upsertArtifact(artifact *openapi.Artifact, modelV
104106

105107
modelArtifact, err := b.mapper.MapFromModelArtifact(ma)
106108
if err != nil {
107-
return nil, err
109+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
108110
}
109111

110112
modelArtifact, err = b.modelArtifactRepository.Save(modelArtifact, modelVersionIDPtr)
111113
if err != nil {
114+
if errors.Is(err, gorm.ErrDuplicatedKey) {
115+
return nil, fmt.Errorf("model artifact with name %s already exists: %w", *ma.Name, api.ErrConflict)
116+
}
117+
112118
return nil, err
113119
}
114120

@@ -150,11 +156,15 @@ func (b *ModelRegistryService) upsertArtifact(artifact *openapi.Artifact, modelV
150156

151157
docArtifact, err := b.mapper.MapFromDocArtifact(da)
152158
if err != nil {
153-
return nil, err
159+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
154160
}
155161

156162
docArtifact, err = b.docArtifactRepository.Save(docArtifact, modelVersionIDPtr)
157163
if err != nil {
164+
if errors.Is(err, gorm.ErrDuplicatedKey) {
165+
return nil, fmt.Errorf("doc artifact with name %s already exists: %w", *da.Name, api.ErrConflict)
166+
}
167+
158168
return nil, err
159169
}
160170

@@ -188,7 +198,7 @@ func (b *ModelRegistryService) getArtifact(id string, preserveName bool) (*opena
188198

189199
artifact, err := b.artifactRepository.GetByID(int32(convertedId))
190200
if err != nil {
191-
return nil, err
201+
return nil, fmt.Errorf("no artifact found for id %s: %w", id, api.ErrNotFound)
192202
}
193203

194204
if artifact.ModelArtifact != nil {
@@ -280,7 +290,7 @@ func (b *ModelRegistryService) GetArtifacts(listOptions api.ListOptions, modelVe
280290
if modelVersionId != nil {
281291
convertedId, err := strconv.ParseInt(*modelVersionId, 10, 32)
282292
if err != nil {
283-
return nil, fmt.Errorf("invalid model version id: %w", err)
293+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
284294
}
285295

286296
convertedIdInt32 := int32(convertedId)
@@ -397,7 +407,7 @@ func (b *ModelRegistryService) GetModelArtifacts(listOptions api.ListOptions, mo
397407
if modelVersionId != nil {
398408
convertedId, err := strconv.ParseInt(*modelVersionId, 10, 32)
399409
if err != nil {
400-
return nil, fmt.Errorf("invalid model version id: %w", err)
410+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
401411
}
402412

403413
convertedIdInt32 := int32(convertedId)

internal/core/artifact_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ func TestUpsertModelVersionArtifact(t *testing.T) {
453453

454454
assert.Error(t, err)
455455
assert.Nil(t, result)
456-
assert.Contains(t, err.Error(), "invalid model version id")
456+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
457457
})
458458

459459
t.Run("unicode characters in model version artifact name", func(t *testing.T) {
@@ -975,7 +975,7 @@ func TestGetArtifacts(t *testing.T) {
975975

976976
assert.Error(t, err)
977977
assert.Nil(t, result)
978-
assert.Contains(t, err.Error(), "invalid model version id")
978+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
979979
})
980980
}
981981

@@ -1419,7 +1419,7 @@ func TestGetModelArtifacts(t *testing.T) {
14191419

14201420
assert.Error(t, err)
14211421
assert.Nil(t, result)
1422-
assert.Contains(t, err.Error(), "invalid model version id")
1422+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
14231423
})
14241424
}
14251425

internal/core/inference_service.go

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"errors"
45
"fmt"
56
"strconv"
67

@@ -10,6 +11,7 @@ import (
1011
"github.com/kubeflow/model-registry/internal/db/models"
1112
"github.com/kubeflow/model-registry/pkg/api"
1213
"github.com/kubeflow/model-registry/pkg/openapi"
14+
"gorm.io/gorm"
1315
)
1416

1517
func (b *ModelRegistryService) UpsertInferenceService(inferenceService *openapi.InferenceService) (*openapi.InferenceService, error) {
@@ -30,16 +32,31 @@ func (b *ModelRegistryService) UpsertInferenceService(inferenceService *openapi.
3032
inferenceService = &withNotEditable
3133
}
3234

35+
_, err := b.GetServingEnvironmentById(inferenceService.ServingEnvironmentId)
36+
if err != nil {
37+
return nil, fmt.Errorf("no serving environment found for id %s: %w", inferenceService.ServingEnvironmentId, api.ErrNotFound)
38+
}
39+
3340
infSvc, err := b.mapper.MapFromInferenceService(inferenceService, inferenceService.ServingEnvironmentId)
3441
if err != nil {
35-
return nil, err
42+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
3643
}
3744

38-
prefixedName := converter.PrefixWhenOwned(&inferenceService.ServingEnvironmentId, *infSvc.GetAttributes().Name)
45+
name := ""
46+
47+
if infSvc.GetAttributes().Name != nil {
48+
name = *infSvc.GetAttributes().Name
49+
}
50+
51+
prefixedName := converter.PrefixWhenOwned(&inferenceService.ServingEnvironmentId, name)
3952
infSvc.GetAttributes().Name = &prefixedName
4053

4154
savedInfSvc, err := b.inferenceServiceRepository.Save(infSvc)
4255
if err != nil {
56+
if errors.Is(err, gorm.ErrDuplicatedKey) {
57+
return nil, fmt.Errorf("inference service with name %s already exists: %w", *infSvc.GetAttributes().Name, api.ErrConflict)
58+
}
59+
4360
return nil, err
4461
}
4562

@@ -114,7 +131,7 @@ func (b *ModelRegistryService) GetInferenceServices(listOptions api.ListOptions,
114131
if servingEnvironmentId != nil {
115132
convertedId, err := strconv.ParseInt(*servingEnvironmentId, 10, 32)
116133
if err != nil {
117-
return nil, fmt.Errorf("invalid serving environment id: %w", err)
134+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
118135
}
119136

120137
id := int32(convertedId)
@@ -143,7 +160,7 @@ func (b *ModelRegistryService) GetInferenceServices(listOptions api.ListOptions,
143160
for _, infSvc := range infServicesList.Items {
144161
inferenceService, err := b.mapper.MapToInferenceService(infSvc)
145162
if err != nil {
146-
return nil, err
163+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
147164
}
148165
inferenceServiceList.Items = append(inferenceServiceList.Items, *inferenceService)
149166
}

internal/core/inference_service_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ func TestGetInferenceServices(t *testing.T) {
881881

882882
assert.Error(t, err)
883883
assert.Nil(t, result)
884-
assert.Contains(t, err.Error(), "invalid serving environment id")
884+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
885885
})
886886
}
887887

internal/core/model_version.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"errors"
45
"fmt"
56
"strconv"
67

@@ -10,6 +11,7 @@ import (
1011
"github.com/kubeflow/model-registry/internal/db/models"
1112
"github.com/kubeflow/model-registry/pkg/api"
1213
"github.com/kubeflow/model-registry/pkg/openapi"
14+
"gorm.io/gorm"
1315
)
1416

1517
func (b *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVersion, registeredModelId *string) (*openapi.ModelVersion, error) {
@@ -36,13 +38,17 @@ func (b *ModelRegistryService) UpsertModelVersion(modelVersion *openapi.ModelVer
3638

3739
model, err := b.mapper.MapFromModelVersion(modelVersion)
3840
if err != nil {
39-
return nil, err
41+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
4042
}
4143

4244
modelVersion.Name = converter.PrefixWhenOwned(&modelVersion.RegisteredModelId, modelVersion.Name)
4345

4446
savedModel, err := b.modelVersionRepository.Save(model)
4547
if err != nil {
48+
if errors.Is(err, gorm.ErrDuplicatedKey) {
49+
return nil, fmt.Errorf("model version with name %s already exists: %w", modelVersion.Name, api.ErrConflict)
50+
}
51+
4652
return nil, err
4753
}
4854

@@ -78,12 +84,12 @@ func (b *ModelRegistryService) GetModelVersionById(id string) (*openapi.ModelVer
7884
func (b *ModelRegistryService) GetModelVersionByInferenceService(inferenceServiceId string) (*openapi.ModelVersion, error) {
7985
convertedId, err := strconv.ParseInt(inferenceServiceId, 10, 32)
8086
if err != nil {
81-
return nil, fmt.Errorf("invalid inference service id: %w", err)
87+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
8288
}
8389

8490
infSvc, err := b.inferenceServiceRepository.GetByID(int32(convertedId))
8591
if err != nil {
86-
return nil, err
92+
return nil, fmt.Errorf("no inference service found for id %s: %w", inferenceServiceId, api.ErrNotFound)
8793
}
8894

8995
infSvcProps := infSvc.GetProperties()
@@ -168,7 +174,7 @@ func (b *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, reg
168174
if registeredModelId != nil {
169175
convertedId, err := strconv.ParseInt(*registeredModelId, 10, 32)
170176
if err != nil {
171-
return nil, fmt.Errorf("invalid registered model id: %w", err)
177+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
172178
}
173179

174180
id := int32(convertedId)
@@ -195,7 +201,7 @@ func (b *ModelRegistryService) GetModelVersions(listOptions api.ListOptions, reg
195201
for _, model := range versionsList.Items {
196202
modelVersion, err := b.mapper.MapToModelVersion(model)
197203
if err != nil {
198-
return nil, err
204+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
199205
}
200206
modelVersionList.Items = append(modelVersionList.Items, *modelVersion)
201207
}

internal/core/model_version_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -585,7 +585,7 @@ func TestGetModelVersionByInferenceService(t *testing.T) {
585585

586586
assert.Error(t, err)
587587
assert.Nil(t, result)
588-
assert.Contains(t, err.Error(), "invalid inference service id")
588+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
589589
})
590590

591591
t.Run("non-existent inference service", func(t *testing.T) {
@@ -798,7 +798,7 @@ func TestGetModelVersions(t *testing.T) {
798798

799799
assert.Error(t, err)
800800
assert.Nil(t, result)
801-
assert.Contains(t, err.Error(), "invalid registered model id")
801+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
802802
})
803803
}
804804

internal/core/registered_model.go

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package core
22

33
import (
4+
"errors"
45
"fmt"
56
"strconv"
67

@@ -9,6 +10,7 @@ import (
910
"github.com/kubeflow/model-registry/internal/db/models"
1011
"github.com/kubeflow/model-registry/pkg/api"
1112
"github.com/kubeflow/model-registry/pkg/openapi"
13+
"gorm.io/gorm"
1214
)
1315

1416
func (b *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi.RegisteredModel) (*openapi.RegisteredModel, error) {
@@ -31,11 +33,15 @@ func (b *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi.Re
3133

3234
model, err := b.mapper.MapFromRegisteredModel(registeredModel)
3335
if err != nil {
34-
return nil, err
36+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
3537
}
3638

3739
savedModel, err := b.registeredModelRepository.Save(model)
3840
if err != nil {
41+
if errors.Is(err, gorm.ErrDuplicatedKey) {
42+
return nil, fmt.Errorf("registered model with name %s already exists: %w", registeredModel.Name, api.ErrConflict)
43+
}
44+
3945
return nil, err
4046
}
4147

@@ -45,7 +51,7 @@ func (b *ModelRegistryService) UpsertRegisteredModel(registeredModel *openapi.Re
4551
func (b *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.RegisteredModel, error) {
4652
convertedId, err := strconv.ParseInt(id, 10, 32)
4753
if err != nil {
48-
return nil, fmt.Errorf("invalid id: %w", err)
54+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
4955
}
5056

5157
model, err := b.registeredModelRepository.GetByID(int32(convertedId))
@@ -59,12 +65,12 @@ func (b *ModelRegistryService) GetRegisteredModelById(id string) (*openapi.Regis
5965
func (b *ModelRegistryService) GetRegisteredModelByInferenceService(inferenceServiceId string) (*openapi.RegisteredModel, error) {
6066
convertedId, err := strconv.ParseInt(inferenceServiceId, 10, 32)
6167
if err != nil {
62-
return nil, fmt.Errorf("invalid inference service id: %w", err)
68+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
6369
}
6470

6571
infSvc, err := b.inferenceServiceRepository.GetByID(int32(convertedId))
6672
if err != nil {
67-
return nil, err
73+
return nil, fmt.Errorf("no inference service found for id %s: %w", inferenceServiceId, api.ErrNotFound)
6874
}
6975

7076
infSvcProps := infSvc.GetProperties()
@@ -115,7 +121,12 @@ func (b *ModelRegistryService) GetRegisteredModelByParams(name *string, external
115121
return nil, fmt.Errorf("multiple registered models found for name=%v, externalId=%v: %w", apiutils.ZeroIfNil(name), apiutils.ZeroIfNil(externalId), api.ErrNotFound)
116122
}
117123

118-
return b.mapper.MapToRegisteredModel(modelsList.Items[0])
124+
registeredModel, err := b.mapper.MapToRegisteredModel(modelsList.Items[0])
125+
if err != nil {
126+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
127+
}
128+
129+
return registeredModel, nil
119130
}
120131

121132
func (b *ModelRegistryService) GetRegisteredModels(listOptions api.ListOptions) (*openapi.RegisteredModelList, error) {
@@ -138,7 +149,7 @@ func (b *ModelRegistryService) GetRegisteredModels(listOptions api.ListOptions)
138149
for _, model := range modelsList.Items {
139150
registeredModel, err := b.mapper.MapToRegisteredModel(model)
140151
if err != nil {
141-
return nil, err
152+
return nil, fmt.Errorf("%v: %w", err, api.ErrBadRequest)
142153
}
143154
registeredModelList.Items = append(registeredModelList.Items, *registeredModel)
144155
}

internal/core/registered_model_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ func TestGetRegisteredModelByInferenceService(t *testing.T) {
410410

411411
assert.Error(t, err)
412412
assert.Nil(t, result)
413-
assert.Contains(t, err.Error(), "invalid inference service id")
413+
assert.Contains(t, err.Error(), "invalid syntax: bad request")
414414
})
415415

416416
t.Run("non-existent inference service", func(t *testing.T) {

0 commit comments

Comments
 (0)