diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerAndStoreFields.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerAndStoreFields.ts index f367e4c1fd..8d4a6ddaaa 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerAndStoreFields.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/pages/modelRegistryView/registerAndStoreFields.ts @@ -301,9 +301,18 @@ class RegisterAndStoreFields { this.findSourceS3SecretAccessKeyInput().type(secretAccessKey); } + /** Sets model type (required on register page). Uses Predictive by default. */ + selectModelType( + optionName: 'Predictive Model' | 'Generative AI model (Example, LLM)' = 'Predictive Model', + ) { + cy.get('#register-model-type-toggle').click(); + cy.findByRole('option', { name: optionName }).click(); + } + // Convenience method to fill all required fields for submission fillAllRequiredFields() { this.fillModelName('test-model'); + this.selectModelType(); this.fillVersionName('v1.0.0'); this.fillJobName('my-transfer-job'); this.fillSourceEndpoint('https://s3.amazonaws.com'); diff --git a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerAndStoreFields.cy.ts b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerAndStoreFields.cy.ts index ba136c6ca1..fd7cbbd5a6 100644 --- a/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerAndStoreFields.cy.ts +++ b/clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelRegistry/registerAndStoreFields.cy.ts @@ -331,6 +331,7 @@ describe('Register and Store Fields - Credential Validation', () => { it('Should have submit button disabled when S3 access key ID is missing', () => { // Fill all fields except S3 access key ID registerAndStoreFields.fillModelName('test-model'); + registerAndStoreFields.selectModelType(); registerAndStoreFields.fillVersionName('v1.0.0'); registerAndStoreFields.fillJobName('my-transfer-job'); registerAndStoreFields.fillSourceEndpoint('https://s3.amazonaws.com'); @@ -349,6 +350,7 @@ describe('Register and Store Fields - Credential Validation', () => { it('Should have submit button disabled when S3 secret access key is missing', () => { // Fill all fields except S3 secret access key registerAndStoreFields.fillModelName('test-model'); + registerAndStoreFields.selectModelType(); registerAndStoreFields.fillVersionName('v1.0.0'); registerAndStoreFields.fillJobName('my-transfer-job'); registerAndStoreFields.fillSourceEndpoint('https://s3.amazonaws.com'); @@ -367,6 +369,7 @@ describe('Register and Store Fields - Credential Validation', () => { it('Should have submit button disabled when OCI username is missing', () => { // Fill all fields except OCI username registerAndStoreFields.fillModelName('test-model'); + registerAndStoreFields.selectModelType(); registerAndStoreFields.fillVersionName('v1.0.0'); registerAndStoreFields.fillJobName('my-transfer-job'); registerAndStoreFields.fillSourceEndpoint('https://s3.amazonaws.com'); @@ -385,6 +388,7 @@ describe('Register and Store Fields - Credential Validation', () => { it('Should have submit button disabled when OCI password is missing', () => { // Fill all fields except OCI password registerAndStoreFields.fillModelName('test-model'); + registerAndStoreFields.selectModelType(); registerAndStoreFields.fillVersionName('v1.0.0'); registerAndStoreFields.fillJobName('my-transfer-job'); registerAndStoreFields.fillSourceEndpoint('https://s3.amazonaws.com'); diff --git a/clients/ui/frontend/src/app/pages/modelCatalog/screens/RegisterCatalogModelForm.tsx b/clients/ui/frontend/src/app/pages/modelCatalog/screens/RegisterCatalogModelForm.tsx index 169a81a612..bee06a69f4 100644 --- a/clients/ui/frontend/src/app/pages/modelCatalog/screens/RegisterCatalogModelForm.tsx +++ b/clients/ui/frontend/src/app/pages/modelCatalog/screens/RegisterCatalogModelForm.tsx @@ -36,6 +36,7 @@ import { import { CatalogArtifacts, CatalogModel, CatalogModelDetailsParams } from '~/app/modelCatalogTypes'; import { getCatalogModelDetailsRoute } from '~/app/routes/modelCatalog/catalogModelDetails'; import { + getCatalogModelTypePropertyForRegistration, getModelArtifactUri, getModelName, } from '~/app/pages/modelCatalog/utils/modelCatalogUtils'; @@ -81,7 +82,11 @@ const RegisterCatalogModelForm: React.FC = ({ jobResourceName: '', modelRegistry: preferredModelRegistry.name, namespace: '', - modelCustomProperties: { ...getLabelsFromCustomProperties(model?.customProperties), ...tasks }, + modelCustomProperties: { + ...getLabelsFromCustomProperties(model?.customProperties), + ...tasks, + ...getCatalogModelTypePropertyForRegistration(model?.customProperties), + }, versionCustomProperties: { ...model?.customProperties, License: { diff --git a/clients/ui/frontend/src/app/pages/modelCatalog/utils/__tests__/modelCatalogUtils.spec.ts b/clients/ui/frontend/src/app/pages/modelCatalog/utils/__tests__/modelCatalogUtils.spec.ts index beaa330722..8b14c561cc 100644 --- a/clients/ui/frontend/src/app/pages/modelCatalog/utils/__tests__/modelCatalogUtils.spec.ts +++ b/clients/ui/frontend/src/app/pages/modelCatalog/utils/__tests__/modelCatalogUtils.spec.ts @@ -19,6 +19,8 @@ import { ModelCatalogTask, ModelCatalogTensorType, UseCaseOptionValue, + CatalogModelCustomPropertyKey, + ModelType, } from '~/concepts/modelCatalog/const'; import { CatalogSourceStatus, @@ -34,6 +36,7 @@ import { hasFiltersApplied, getArchitecturesFromArtifacts, getModelName, + getCatalogModelTypePropertyForRegistration, getActiveSourceLabels, } from '~/app/pages/modelCatalog/utils/modelCatalogUtils'; import { mockCatalogModelArtifact } from '~/__mocks__/mockCatalogModelArtifactList'; @@ -1387,6 +1390,37 @@ describe('getModelName', () => { }); }); +describe('getCatalogModelTypePropertyForRegistration', () => { + it('returns model_type metadata when catalog has generative or predictive', () => { + expect( + getCatalogModelTypePropertyForRegistration({ + [CatalogModelCustomPropertyKey.MODEL_TYPE]: { + metadataType: ModelRegistryMetadataType.STRING, + string_value: ModelType.GENERATIVE, + }, + }), + ).toEqual({ + [CatalogModelCustomPropertyKey.MODEL_TYPE]: { + metadataType: ModelRegistryMetadataType.STRING, + string_value: ModelType.GENERATIVE, + }, + }); + }); + + it('returns empty object when model_type is absent or not a registerable value', () => { + expect(getCatalogModelTypePropertyForRegistration(undefined)).toEqual({}); + expect(getCatalogModelTypePropertyForRegistration({})).toEqual({}); + expect( + getCatalogModelTypePropertyForRegistration({ + [CatalogModelCustomPropertyKey.MODEL_TYPE]: { + metadataType: ModelRegistryMetadataType.STRING, + string_value: ModelType.UNKNOWN, + }, + }), + ).toEqual({}); + }); +}); + describe('getActiveSourceLabels', () => { const createSource = (overrides: Partial = {}): CatalogSource => ({ id: 'source-1', diff --git a/clients/ui/frontend/src/app/pages/modelCatalog/utils/modelCatalogUtils.ts b/clients/ui/frontend/src/app/pages/modelCatalog/utils/modelCatalogUtils.ts index 133c8d4e1d..a9c3496bc0 100644 --- a/clients/ui/frontend/src/app/pages/modelCatalog/utils/modelCatalogUtils.ts +++ b/clients/ui/frontend/src/app/pages/modelCatalog/utils/modelCatalogUtils.ts @@ -35,7 +35,11 @@ import { ModelType, } from '~/concepts/modelCatalog/const'; import { isSourceStatusWithModels } from '~/concepts/modelCatalogSettings/const'; -import { ModelRegistryMetadataType } from '~/app/types'; +import { ModelRegistryCustomProperties, ModelRegistryMetadataType } from '~/app/types'; +import { + buildCustomPropertiesWithModelType, + getModelTypeStoredValueFromCustomProperties, +} from '~/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils'; /** * Prefix used by the backend for artifact-specific filter options. @@ -789,3 +793,14 @@ export const formatModelTypeDisplay = (modelTypeRaw: string | null): string => { // Fallback: capitalize whatever value we got return capitalize(modelTypeRaw.trim()); }; + +/** + * Returns model registry customProperties entries to prefill `model_type` when registering + * from the catalog. Only generative/predictive are copied; unknown or missing values yield {}. + */ +export const getCatalogModelTypePropertyForRegistration = ( + customProperties?: ModelRegistryCustomProperties, +): ModelRegistryCustomProperties => { + const stored = getModelTypeStoredValueFromCustomProperties(customProperties); + return buildCustomPropertiesWithModelType(undefined, stored); +}; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx index 174efd6ecc..de6063c1e7 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersionDetails/ModelVersionDetailsView.tsx @@ -105,6 +105,7 @@ const ModelVersionDetailsView: React.FC = ({ refresh={refresh} isArchiveModel={isArchiveVersion} isExpandable + modelTypeFallbackCustomProperties={mv.customProperties} /> )} diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsCard.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsCard.tsx index bd96f3bd95..525a172366 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsCard.tsx +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/ModelVersions/ModelDetailsCard.tsx @@ -20,18 +20,22 @@ import { DashboardDescriptionListGroup, EditableLabelsDescriptionListGroup, } from 'mod-arch-shared'; -import { RegisteredModel } from '~/app/types'; +import { ModelRegistryCustomProperties, RegisteredModel } from '~/app/types'; import ModelTimestamp from '~/app/pages/modelRegistry/screens/components/ModelTimestamp'; import ModelPropertiesExpandableSection from '~/app/pages/modelRegistry/screens/components/ModelPropertiesExpandableSection'; import { ModelRegistryContext } from '~/app/context/ModelRegistryContext'; import { getLabels, mergeUpdatedLabels } from '~/app/pages/modelRegistry/screens/utils'; import { EMPTY_CUSTOM_PROPERTY_VALUE } from '~/concepts/modelCatalog/const'; +import { formatModelTypeDisplay } from '~/app/pages/modelCatalog/utils/modelCatalogUtils'; +import { getModelTypeRawStringFromCustomProperties } from '~/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils'; type ModelDetailsCardProps = { registeredModel: RegisteredModel; refresh: () => void; isArchiveModel?: boolean; isExpandable?: boolean; + /** If `model_type` is absent on the registered model, read from these (e.g. version custom properties). */ + modelTypeFallbackCustomProperties?: ModelRegistryCustomProperties; }; const ModelDetailsCard: React.FC = ({ @@ -39,6 +43,7 @@ const ModelDetailsCard: React.FC = ({ refresh, isArchiveModel, isExpandable, + modelTypeFallbackCustomProperties, }) => { const { apiState } = React.useContext(ModelRegistryContext); const [isExpanded, setIsExpanded] = React.useState(false); @@ -112,8 +117,17 @@ const ModelDetailsCard: React.FC = ({ /> ); + const modelTypeRaw = + getModelTypeRawStringFromCustomProperties(rm.customProperties) ?? + getModelTypeRawStringFromCustomProperties(modelTypeFallbackCustomProperties); + const infoSection = ( <> + + + {formatModelTypeDisplay(modelTypeRaw)} + + { registeredModels, namespaceHasAccess, isNamespaceAccessLoading, + { requireModelType: true }, ); const handleSubmit = async () => { @@ -160,6 +161,7 @@ const RegisterModel: React.FC = () => { setData={setData} hasModelNameError={hasModelNameError} isModelNameDuplicate={isModelNameDuplicate} + isModelTypeRequired /> = { @@ -18,12 +19,15 @@ type RegisterModelDetailsFormSectionProp = { setData: UpdateObjectAtPropAndValue; hasModelNameError: boolean; isModelNameDuplicate?: boolean; + /** When true (MR register-from-registry), model type is required; submit stays disabled until set. */ + isModelTypeRequired?: boolean; }; const RegisterModelDetailsFormSection = ({ formData, setData, hasModelNameError, isModelNameDuplicate, + isModelTypeRequired = false, }: RegisterModelDetailsFormSectionProp): React.ReactNode => { const modelNameInput = ( ({ + setData('modelCustomProperties', next)} + isRequired={isModelTypeRequired} + /> ); }; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModelTypeField.tsx b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModelTypeField.tsx new file mode 100644 index 0000000000..0b5820d7d7 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/RegisterModelTypeField.tsx @@ -0,0 +1,66 @@ +import { FormGroup } from '@patternfly/react-core'; +import React from 'react'; +import { SimpleSelect } from 'mod-arch-shared'; +import { SimpleSelectOption } from 'mod-arch-shared/dist/components/SimpleSelect'; +import { ModelRegistryCustomProperties } from '~/app/types'; +import { ModelType } from '~/concepts/modelCatalog/const'; +import { formatModelTypeDisplay } from '~/app/pages/modelCatalog/utils/modelCatalogUtils'; +import FormFieldset from '~/app/pages/modelRegistry/screens/components/FormFieldset'; +import { + buildCustomPropertiesWithModelType, + getModelTypeStoredValueFromCustomProperties, +} from './registerModelTypeUtils'; + +const MODEL_TYPE_SELECT_OPTIONS: SimpleSelectOption[] = [ + { + key: ModelType.GENERATIVE, + label: formatModelTypeDisplay(ModelType.GENERATIVE), + }, + { + key: ModelType.PREDICTIVE, + label: formatModelTypeDisplay(ModelType.PREDICTIVE), + }, +]; + +type RegisterModelTypeFieldProps = { + modelCustomProperties: ModelRegistryCustomProperties | undefined; + onModelCustomPropertiesChange: (next: ModelRegistryCustomProperties) => void; + isRequired?: boolean; +}; + +const RegisterModelTypeField: React.FC = ({ + modelCustomProperties, + onModelCustomPropertiesChange, + isRequired, +}) => { + const stored = getModelTypeStoredValueFromCustomProperties(modelCustomProperties); + + const handleChange = (key: string) => { + if (key === ModelType.GENERATIVE || key === ModelType.PREDICTIVE) { + onModelCustomPropertiesChange(buildCustomPropertiesWithModelType(modelCustomProperties, key)); + } + }; + + return ( + + + } + /> + + ); +}; + +export default RegisterModelTypeField; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/registerModelTypeUtils.spec.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/registerModelTypeUtils.spec.ts new file mode 100644 index 0000000000..907d391b77 --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/registerModelTypeUtils.spec.ts @@ -0,0 +1,158 @@ +import { + ModelRegistryCustomProperties, + ModelRegistryCustomPropertyString, + ModelRegistryMetadataType, +} from '~/app/types'; +import { ModelType } from '~/concepts/modelCatalog/const'; +import { + buildCustomPropertiesWithModelType, + getModelTypeRawStringFromCustomProperties, + getModelTypeStoredValueFromCustomProperties, + MODEL_TYPE_CUSTOM_PROPERTY_KEY, +} from '~/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils'; + +const stringProp = (value: string): ModelRegistryCustomPropertyString => ({ + metadataType: ModelRegistryMetadataType.STRING, + // eslint-disable-next-line camelcase + string_value: value, +}); + +describe('registerModelTypeUtils', () => { + describe('getModelTypeRawStringFromCustomProperties', () => { + it('returns null when custom properties are undefined', () => { + expect(getModelTypeRawStringFromCustomProperties(undefined)).toBeNull(); + }); + + it('returns null when model_type is missing', () => { + expect(getModelTypeRawStringFromCustomProperties({})).toBeNull(); + }); + + it('returns null when metadata is not STRING', () => { + expect( + getModelTypeRawStringFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: { + metadataType: ModelRegistryMetadataType.INT, + // eslint-disable-next-line camelcase + int_value: '1', + }, + }), + ).toBeNull(); + }); + + it('returns null when string_value is empty or whitespace', () => { + expect( + getModelTypeRawStringFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(''), + }), + ).toBeNull(); + expect( + getModelTypeRawStringFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(' '), + }), + ).toBeNull(); + }); + + it('returns trimmed raw string without normalizing case', () => { + expect( + getModelTypeRawStringFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(' Generative '), + }), + ).toBe('Generative'); + }); + }); + + describe('getModelTypeStoredValueFromCustomProperties', () => { + it('returns undefined when custom properties are undefined', () => { + expect(getModelTypeStoredValueFromCustomProperties(undefined)).toBeUndefined(); + }); + + it('returns undefined when model_type is missing or not STRING', () => { + expect(getModelTypeStoredValueFromCustomProperties({})).toBeUndefined(); + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: { + metadataType: ModelRegistryMetadataType.INT, + // eslint-disable-next-line camelcase + int_value: '1', + }, + }), + ).toBeUndefined(); + }); + + it('returns undefined for STRING values that are not generative or predictive', () => { + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.UNKNOWN), + }), + ).toBeUndefined(); + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp('other'), + }), + ).toBeUndefined(); + }); + + it('returns generative or predictive when string matches after lowercasing and trim', () => { + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.GENERATIVE), + }), + ).toBe(ModelType.GENERATIVE); + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(' GENERATIVE '), + }), + ).toBe(ModelType.GENERATIVE); + expect( + getModelTypeStoredValueFromCustomProperties({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.PREDICTIVE), + }), + ).toBe(ModelType.PREDICTIVE); + }); + }); + + describe('buildCustomPropertiesWithModelType', () => { + it('returns empty object when base is undefined and model type is cleared', () => { + expect(buildCustomPropertiesWithModelType(undefined, undefined)).toEqual({}); + }); + + it('sets model_type when next is generative or predictive', () => { + expect(buildCustomPropertiesWithModelType(undefined, ModelType.GENERATIVE)).toEqual({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.GENERATIVE), + }); + expect(buildCustomPropertiesWithModelType(undefined, ModelType.PREDICTIVE)).toEqual({ + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.PREDICTIVE), + }); + }); + + it('merges with base properties and overwrites model_type', () => { + const base: ModelRegistryCustomProperties = { + otherKey: stringProp('keep-me'), + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.GENERATIVE), + }; + expect(buildCustomPropertiesWithModelType(base, ModelType.PREDICTIVE)).toEqual({ + otherKey: stringProp('keep-me'), + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.PREDICTIVE), + }); + }); + + it('removes model_type when next is undefined and preserves other keys', () => { + const base: ModelRegistryCustomProperties = { + otherKey: stringProp('keep-me'), + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.GENERATIVE), + }; + expect(buildCustomPropertiesWithModelType(base, undefined)).toEqual({ + otherKey: stringProp('keep-me'), + }); + }); + + it('does not mutate the base object', () => { + const base: ModelRegistryCustomProperties = { + [MODEL_TYPE_CUSTOM_PROPERTY_KEY]: stringProp(ModelType.GENERATIVE), + }; + const copy = { ...base }; + buildCustomPropertiesWithModelType(base, ModelType.PREDICTIVE); + expect(base).toEqual(copy); + }); + }); +}); diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/utils.spec.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/utils.spec.ts index 2d9cfa3a13..ce8a0c2f78 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/utils.spec.ts +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/__tests__/utils.spec.ts @@ -1,5 +1,6 @@ import { RegisteredModelList, + ModelRegistryMetadataType, ModelTransferJobSourceType, ModelTransferJobDestinationType, ModelTransferJobUploadIntent, @@ -9,12 +10,127 @@ import { isModelNameExisting, isNameValid, buildModelTransferJobPayload, + isRegisterModelSubmitDisabled, + isRegisterCatalogModelSubmitDisabled, } from '~/app/pages/modelRegistry/screens/RegisterModel/utils'; import { MR_CHARACTER_LIMIT } from '~/app/pages/modelRegistry/screens/RegisterModel/const'; -import { ModelLocationType } from '~/app/pages/modelRegistry/screens/RegisterModel/useRegisterModelData'; +import { + ModelLocationType, + RegisterModelFormData, +} from '~/app/pages/modelRegistry/screens/RegisterModel/useRegisterModelData'; import { RegistrationMode } from '~/app/pages/modelRegistry/screens/const'; +import { CatalogModelCustomPropertyKey, ModelType } from '~/concepts/modelCatalog/const'; + +/** Shared fields for registration utils tests (transfer job + submit-disabled). */ +const registrationFormTestBase = { + versionName: 'v1.0.0', + versionDescription: 'Test version', + sourceModelFormat: 'onnx', + sourceModelFormatVersion: '1.0', + modelLocationType: ModelLocationType.ObjectStorage, + modelLocationEndpoint: 'https://s3.amazonaws.com', + modelLocationBucket: 'test-bucket', + modelLocationRegion: 'us-east-1', + modelLocationPath: 'models/test', + modelLocationURI: '', + modelLocationS3AccessKeyId: '', + modelLocationS3SecretAccessKey: '', + registrationMode: RegistrationMode.RegisterAndStore, + namespace: 'test-namespace', + destinationOciRegistry: 'quay.io', + destinationOciUsername: '', + destinationOciPassword: '', + destinationOciUri: 'quay.io/org/model:v1', + jobName: 'test-job', + jobResourceName: 'test-job-resource', + versionCustomProperties: {}, + additionalArtifactProperties: {}, +}; describe('RegisterModel utils', () => { + const emptyRegisteredModelList = { + items: [], + size: 0, + pageSize: 20, + nextPageToken: '', + } as RegisteredModelList; + + /** Register + URI path: only model type should gate MR submit when required. */ + const mrRegisterForm = ( + modelCustomProperties: RegisterModelFormData['modelCustomProperties'], + ): RegisterModelFormData => ({ + ...registrationFormTestBase, + modelName: 'unique-new-model', + modelDescription: '', + registrationMode: RegistrationMode.Register, + modelLocationType: ModelLocationType.URI, + modelLocationURI: 'https://example.com/model.onnx', + modelLocationEndpoint: '', + modelLocationBucket: '', + modelLocationRegion: '', + modelLocationPath: '', + namespace: '', + destinationOciRegistry: '', + destinationOciUsername: '', + destinationOciPassword: '', + destinationOciUri: '', + jobName: '', + jobResourceName: '', + modelCustomProperties, + versionCustomProperties: {}, + }); + + describe('isRegisterModelSubmitDisabled (model type)', () => { + it('disables submit until model type is selected when requireModelType is true', () => { + expect( + isRegisterModelSubmitDisabled( + mrRegisterForm({}), + emptyRegisteredModelList, + undefined, + undefined, + { + requireModelType: true, + }, + ), + ).toBe(true); + }); + + it('allows submit once model type is set when requireModelType is true', () => { + expect( + isRegisterModelSubmitDisabled( + mrRegisterForm({ + [CatalogModelCustomPropertyKey.MODEL_TYPE]: { + metadataType: ModelRegistryMetadataType.STRING, + // eslint-disable-next-line camelcase + string_value: ModelType.GENERATIVE, + }, + }), + emptyRegisteredModelList, + undefined, + undefined, + { requireModelType: true }, + ), + ).toBe(false); + }); + + it('does not require model type by default', () => { + expect(isRegisterModelSubmitDisabled(mrRegisterForm({}), emptyRegisteredModelList)).toBe( + false, + ); + }); + }); + + describe('isRegisterCatalogModelSubmitDisabled', () => { + it('allows submit without model type when registry is selected', () => { + expect( + isRegisterCatalogModelSubmitDisabled( + { ...mrRegisterForm({}), modelRegistry: 'test-mr' }, + emptyRegisteredModelList, + ), + ).toBe(false); + }); + }); + describe('isModelNameExisting', () => { const existingModelName = 'model2'; const newModelName = 'model4'; @@ -41,34 +157,12 @@ describe('RegisterModel utils', () => { }); describe('buildModelTransferJobPayload', () => { - const baseFormData = { - versionName: 'v1.0.0', - versionDescription: 'Test version', - sourceModelFormat: 'onnx', - sourceModelFormatVersion: '1.0', - modelLocationType: ModelLocationType.ObjectStorage, - modelLocationEndpoint: 'https://s3.amazonaws.com', - modelLocationBucket: 'test-bucket', - modelLocationRegion: 'us-east-1', - modelLocationPath: 'models/test', - modelLocationURI: '', - modelLocationS3AccessKeyId: '', - modelLocationS3SecretAccessKey: '', - registrationMode: RegistrationMode.RegisterAndStore, - namespace: 'test-namespace', - destinationOciRegistry: 'quay.io', - destinationOciUsername: '', - destinationOciPassword: '', - destinationOciUri: 'quay.io/org/model:v1', - - jobName: 'test-job', - jobResourceName: 'test-job-resource', - versionCustomProperties: {}, - additionalArtifactProperties: {}, - }; - it('should build payload with S3 source for ObjectStorage location type', () => { - const formData = { ...baseFormData, modelName: 'Test Model', modelDescription: '' }; + const formData = { + ...registrationFormTestBase, + modelName: 'Test Model', + modelDescription: '', + }; const payload = buildModelTransferJobPayload( formData, 'test-author', @@ -85,7 +179,7 @@ describe('RegisterModel utils', () => { it('should build payload with URI source for URI location type', () => { const formData = { - ...baseFormData, + ...registrationFormTestBase, modelName: 'Test Model', modelDescription: '', modelLocationType: ModelLocationType.URI, @@ -102,7 +196,11 @@ describe('RegisterModel utils', () => { }); it('should build OCI destination correctly', () => { - const formData = { ...baseFormData, modelName: 'Test Model', modelDescription: '' }; + const formData = { + ...registrationFormTestBase, + modelName: 'Test Model', + modelDescription: '', + }; const payload = buildModelTransferJobPayload( formData, 'test-author', @@ -117,7 +215,11 @@ describe('RegisterModel utils', () => { }); it('should set CREATE_MODEL intent and include model name', () => { - const formData = { ...baseFormData, modelName: 'My New Model', modelDescription: '' }; + const formData = { + ...registrationFormTestBase, + modelName: 'My New Model', + modelDescription: '', + }; const payload = buildModelTransferJobPayload( formData, 'test-author', @@ -129,7 +231,7 @@ describe('RegisterModel utils', () => { }); it('should set CREATE_VERSION intent with registeredModelId', () => { - const formData = { ...baseFormData, registeredModelId: 'existing-model-123' }; + const formData = { ...registrationFormTestBase, registeredModelId: 'existing-model-123' }; const payload = buildModelTransferJobPayload( formData, 'test-author', @@ -144,7 +246,11 @@ describe('RegisterModel utils', () => { }); it('should include namespace, author, and job resource name', () => { - const formData = { ...baseFormData, modelName: 'Test Model', modelDescription: '' }; + const formData = { + ...registrationFormTestBase, + modelName: 'Test Model', + modelDescription: '', + }; const payload = buildModelTransferJobPayload( formData, 'test-author', @@ -157,7 +263,11 @@ describe('RegisterModel utils', () => { }); it('should set PENDING status and omit server-generated fields', () => { - const formData = { ...baseFormData, modelName: 'Test Model', modelDescription: '' }; + const formData = { + ...registrationFormTestBase, + modelName: 'Test Model', + modelDescription: '', + }; const payload = buildModelTransferJobPayload( formData, 'test-author', diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils.ts new file mode 100644 index 0000000000..9aac8596ea --- /dev/null +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/registerModelTypeUtils.ts @@ -0,0 +1,49 @@ +import { CatalogModelCustomPropertyKey, ModelType } from '~/concepts/modelCatalog/const'; +import { ModelRegistryCustomProperties, ModelRegistryMetadataType } from '~/app/types'; + +export const MODEL_TYPE_CUSTOM_PROPERTY_KEY = CatalogModelCustomPropertyKey.MODEL_TYPE; + +export type RegisterableModelType = ModelType.GENERATIVE | ModelType.PREDICTIVE; + +/** Raw `model_type` string for display (any non-empty STRING metadata), or null if unset. */ +export const getModelTypeRawStringFromCustomProperties = ( + customProperties: ModelRegistryCustomProperties | undefined, +): string | null => { + const prop = customProperties?.[MODEL_TYPE_CUSTOM_PROPERTY_KEY]; + if (!prop || prop.metadataType !== ModelRegistryMetadataType.STRING) { + return null; + } + const v = prop.string_value.trim(); + return v || null; +}; + +export const getModelTypeStoredValueFromCustomProperties = ( + props: ModelRegistryCustomProperties | undefined, +): RegisterableModelType | undefined => { + const prop = props?.[MODEL_TYPE_CUSTOM_PROPERTY_KEY]; + if (!prop || prop.metadataType !== ModelRegistryMetadataType.STRING) { + return undefined; + } + const v = prop.string_value.toLowerCase().trim(); + if (v === ModelType.GENERATIVE || v === ModelType.PREDICTIVE) { + return v; + } + return undefined; +}; + +export const buildCustomPropertiesWithModelType = ( + base: ModelRegistryCustomProperties | undefined, + next: RegisterableModelType | undefined, +): ModelRegistryCustomProperties => { + const result = { ...(base ?? {}) }; + if (!next) { + delete result[MODEL_TYPE_CUSTOM_PROPERTY_KEY]; + } else { + result[MODEL_TYPE_CUSTOM_PROPERTY_KEY] = { + metadataType: ModelRegistryMetadataType.STRING, + // eslint-disable-next-line camelcase + string_value: next, + }; + } + return result; +}; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/utils.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/utils.ts index 37fd649473..5f0604d0b7 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/utils.ts +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/RegisterModel/utils.ts @@ -25,6 +25,7 @@ import { RegistrationCommonFormData, } from './useRegisterModelData'; import { RegistrationErrorType, MR_CHARACTER_LIMIT } from './const'; +import { getModelTypeStoredValueFromCustomProperties } from './registerModelTypeUtils'; export type RegisterModelCreatedResources = RegisterVersionCreatedResources & { registeredModel?: RegisteredModel; @@ -195,16 +196,24 @@ const isSubmitDisabledForCommonFields = ( ); }; +/** Submit disabled check for register-model flow. Pass `{ requireModelType: true }` when the UI collects model type. */ export const isRegisterModelSubmitDisabled = ( formData: RegisterModelFormData, registeredModels: RegisteredModelList, namespaceHasAccess?: boolean, isNamespaceAccessLoading?: boolean, -): boolean => - !formData.modelName || - isSubmitDisabledForCommonFields(formData, namespaceHasAccess, isNamespaceAccessLoading) || - !isNameValid(formData.modelName) || - isModelNameExisting(formData.modelName, registeredModels); + options?: { requireModelType?: boolean }, +): boolean => { + const requireModelType = options?.requireModelType ?? false; + return ( + !formData.modelName || + isSubmitDisabledForCommonFields(formData, namespaceHasAccess, isNamespaceAccessLoading) || + !isNameValid(formData.modelName) || + isModelNameExisting(formData.modelName, registeredModels) || + (requireModelType && + !getModelTypeStoredValueFromCustomProperties(formData.modelCustomProperties)) + ); +}; export const isRegisterVersionSubmitDisabled = ( formData: RegisterVersionFormData, @@ -225,6 +234,7 @@ export const isRegisterCatalogModelSubmitDisabled = ( registeredModels, namespaceHasAccess, isNamespaceAccessLoading, + { requireModelType: false }, ) || !formData.modelRegistry; export const isNameValid = (name: string): boolean => name.length <= MR_CHARACTER_LIMIT; diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts index a6ecc94535..7ccfdeda14 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/__tests__/utils.spec.ts @@ -177,10 +177,11 @@ describe('getProperties', () => { }); }); - it('should return with _lastModified and _registeredFrom props filtered out', () => { + it('should return with _lastModified, _registeredFrom, and model_type props filtered out', () => { const customProperties: ModelRegistryCustomProperties = { property1: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, _lastModified: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty' }, + model_type: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'predictive' }, _registeredFromSomething: { metadataType: ModelRegistryMetadataType.STRING, string_value: 'non-empty', diff --git a/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts b/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts index bdbc714eeb..f8c6e78fb6 100644 --- a/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts +++ b/clients/ui/frontend/src/app/pages/modelRegistry/screens/utils.ts @@ -81,7 +81,7 @@ export const getPropertyValue = ( } }; -// Retrieves the customProperties that are not special (_registeredFrom) or labels (they have a defined string_value). +// Retrieves the customProperties that are not special (_registeredFrom/model_type) or labels (they have a defined string_value). // Now includes INT and DOUBLE types in addition to STRING export const getProperties = ( customProperties: T, @@ -90,7 +90,7 @@ export const getProperties = ( return Object.keys(customProperties).reduce((acc, key) => { // _lastModified is a property that is required to update the timestamp on the backend and we have a workaround for it. It should be resolved by // backend team - if (key === '_lastModified' || /^_registeredFrom/.test(key)) { + if (key === '_lastModified' || key === 'model_type' || /^_registeredFrom/.test(key)) { return acc; }