Skip to content

Commit 23c8cf8

Browse files
authored
Add tensor type filter to basic filters (#2135)
* Add tensor type filter to basic filters Signed-off-by: ppadti <[email protected]> * Address review comments Signed-off-by: ppadti <[email protected]> --------- Signed-off-by: ppadti <[email protected]>
1 parent e55a322 commit 23c8cf8

File tree

9 files changed

+233
-44
lines changed

9 files changed

+233
-44
lines changed

clients/ui/frontend/src/__mocks__/mockCatalogFilterOptionsList.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import {
99
AllLanguageCode,
1010
UseCaseOptionValue,
1111
DEFAULT_PERFORMANCE_FILTERS_QUERY_NAME,
12+
ModelCatalogTensorType,
1213
} from '~/concepts/modelCatalog/const';
1314

1415
export const mockNamedQueries: Record<string, NamedQuery> = {
@@ -99,6 +100,16 @@ export const mockCatalogFilterOptionsList = (
99100
AllLanguageCode.ZH,
100101
],
101102
},
103+
[ModelCatalogStringFilterKey.TENSOR_TYPE]: {
104+
type: 'string',
105+
values: [
106+
ModelCatalogTensorType.FP16,
107+
ModelCatalogTensorType.FP8,
108+
ModelCatalogTensorType.INT4,
109+
ModelCatalogTensorType.INT8,
110+
ModelCatalogTensorType.MXFP4,
111+
],
112+
},
102113
[ModelCatalogStringFilterKey.HARDWARE_TYPE]: {
103114
type: 'string',
104115
values: ['GPU', 'CPU', 'TPU', 'FPGA'],

clients/ui/frontend/src/__tests__/cypress/cypress/tests/mocked/modelCatalog/modelCatalog.cy.ts

Lines changed: 73 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,25 @@ type HandlersProps = {
6969
includeAllModelsIntercept?: boolean;
7070
};
7171

72+
const calculateExpectedCategoryCount = (sources: CatalogSource[]): number => {
73+
const uniqueLabels = new Set<string>();
74+
sources.forEach((source) => {
75+
source.labels.forEach((label) => {
76+
if (label.trim()) {
77+
uniqueLabels.add(label.trim());
78+
}
79+
});
80+
});
81+
82+
const hasSourcesWithoutLabels = sources.some(
83+
(source) =>
84+
source.enabled !== false &&
85+
(source.labels.length === 0 || source.labels.every((label) => !label.trim())),
86+
);
87+
88+
return uniqueLabels.size + (hasSourcesWithoutLabels ? 1 : 0);
89+
};
90+
7291
const initIntercepts = ({
7392
sources = [mockCatalogSource({}), mockCatalogSource({ id: 'source-2', name: 'source 2' })],
7493
modelsPerCategory = 4,
@@ -213,6 +232,7 @@ describe('Model Catalog Page', () => {
213232
modelCatalog.findFilter('License').should('be.visible');
214233
modelCatalog.findFilter('Task').should('be.visible');
215234
modelCatalog.findFilter('Language').should('be.visible');
235+
modelCatalog.findFilter('Tensor type').scrollIntoView().should('be.visible');
216236
});
217237

218238
it('filters show more and show less button should work', () => {
@@ -245,24 +265,8 @@ describe('Model Catalog Page', () => {
245265
mockCatalogSource({}),
246266
mockCatalogSource({ id: 'source-2', name: 'source 2' }),
247267
];
248-
const uniqueLabels = new Set<string>();
249-
defaultSources.forEach((source) => {
250-
source.labels.forEach((label) => {
251-
if (label.trim()) {
252-
uniqueLabels.add(label.trim());
253-
}
254-
});
255-
});
256-
257-
// Check if there are sources without labels
258-
const hasSourcesWithoutLabels = defaultSources.some(
259-
(source) =>
260-
source.enabled !== false &&
261-
(source.labels.length === 0 || source.labels.every((label) => !label.trim())),
262-
);
263268

264-
// Expected count: unique labels + (1 if sources without labels exist)
265-
const expectedCategoryCount = uniqueLabels.size + (hasSourcesWithoutLabels ? 1 : 0);
269+
const expectedCategoryCount = calculateExpectedCategoryCount(defaultSources);
266270

267271
initIntercepts({ sources: defaultSources, includeAllModelsIntercept: false });
268272

@@ -286,6 +290,58 @@ describe('Model Catalog Page', () => {
286290
);
287291
});
288292
});
293+
294+
it('tensor type filter checkbox should work', () => {
295+
initIntercepts({ includeAllModelsIntercept: true });
296+
297+
setupFilteredModelsIntercept({
298+
returnModelsForFilters: true,
299+
modelsToReturn: [mockCatalogModel({})],
300+
});
301+
302+
modelCatalog.visit();
303+
modelCatalog.findFilterCheckbox('Tensor type', 'FP16').click();
304+
cy.wait('@getFilteredModels');
305+
306+
modelCatalog.findFilterCheckbox('Tensor type', 'INT8').click();
307+
308+
cy.wait('@getFilteredModels').then((interception) => {
309+
expect(interception.request.url).to.include(
310+
'tensor_type.string_value+IN+%28%27FP16%27%2C%27INT8%27%29',
311+
);
312+
});
313+
});
314+
315+
it('tensor type filter combined with other filters should work', () => {
316+
const defaultSources = [
317+
mockCatalogSource({}),
318+
mockCatalogSource({ id: 'source-2', name: 'source 2' }),
319+
];
320+
321+
const expectedCategoryCount = calculateExpectedCategoryCount(defaultSources);
322+
323+
initIntercepts({ sources: defaultSources, includeAllModelsIntercept: false });
324+
325+
setupFilteredModelsIntercept({
326+
returnModelsForFilters: true,
327+
modelsToReturn: [mockCatalogModel({})],
328+
});
329+
330+
modelCatalog.visit();
331+
modelCatalog.findFilterShowMoreButton('Task').click();
332+
modelCatalog.findFilterCheckbox('Task', 'text-generation').click();
333+
modelCatalog.findFilterCheckbox('Tensor type', 'FP16').click();
334+
modelCatalog.findFilterCheckbox('Provider', 'Google').click();
335+
336+
const waitCalls = Array.from({ length: expectedCategoryCount }, () => '@getFilteredModels');
337+
cy.wait(waitCalls).then((interceptions) => {
338+
const lastInterception = interceptions[interceptions.length - 1];
339+
const { url } = lastInterception.request;
340+
expect(url).to.include('tasks%3D%27text-generation%27');
341+
expect(url).to.include('tensor_type.string_value%3D%27FP16%27');
342+
expect(url).to.include('provider%3D%27Google%27');
343+
});
344+
});
289345
});
290346

291347
describe('Performance Empty State', () => {

clients/ui/frontend/src/app/context/modelCatalog/ModelCatalogContext.tsx

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ export const ModelCatalogContext = React.createContext<ModelCatalogContextType>(
8282
[ModelCatalogStringFilterKey.HARDWARE_CONFIGURATION]: [],
8383
[ModelCatalogStringFilterKey.USE_CASE]: [],
8484
[ModelCatalogNumberFilterKey.MAX_RPS]: undefined,
85+
[ModelCatalogStringFilterKey.TENSOR_TYPE]: [],
8586
},
8687
updateSelectedSource: () => undefined,
8788
selectedSourceLabel: undefined,
@@ -124,6 +125,7 @@ export const ModelCatalogContextProvider: React.FC<ModelCatalogContextProviderPr
124125
[ModelCatalogStringFilterKey.HARDWARE_CONFIGURATION]: [],
125126
[ModelCatalogStringFilterKey.USE_CASE]: [],
126127
[ModelCatalogNumberFilterKey.MAX_RPS]: undefined,
128+
[ModelCatalogStringFilterKey.TENSOR_TYPE]: [],
127129
});
128130
const [filterOptions, filterOptionsLoaded, filterOptionsLoadError] =
129131
useCatalogFilterOptionList(apiState);
@@ -173,14 +175,15 @@ export const ModelCatalogContextProvider: React.FC<ModelCatalogContextProviderPr
173175
}, [filterOptions?.namedQueries, applyNamedQueryDefaults, baseSetFilterData]);
174176

175177
/**
176-
* Clears basic filters (Task, Provider, License, Language) to empty.
178+
* Clears basic filters (Task, Provider, License, Language, Tensor Type) to empty.
177179
* Note: BASIC_FILTER_KEYS in const.ts should be updated if basic filters change.
178180
*/
179181
const clearBasicFilters = React.useCallback(() => {
180182
baseSetFilterData(ModelCatalogStringFilterKey.TASK, []);
181183
baseSetFilterData(ModelCatalogStringFilterKey.PROVIDER, []);
182184
baseSetFilterData(ModelCatalogStringFilterKey.LICENSE, []);
183185
baseSetFilterData(ModelCatalogStringFilterKey.LANGUAGE, []);
186+
baseSetFilterData(ModelCatalogStringFilterKey.TENSOR_TYPE, []);
184187
}, [baseSetFilterData]);
185188

186189
/**

clients/ui/frontend/src/app/modelCatalogTypes.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import {
1010
LatencyPropertyKey,
1111
UseCaseOptionValue,
1212
ModelCatalogFilterKey,
13+
ModelCatalogTensorType,
1314
} from '~/concepts/modelCatalog/const';
1415
import {
1516
ModelRegistryCustomProperties,
@@ -236,6 +237,7 @@ export type ModelCatalogStringFilterValueType = {
236237
[ModelCatalogStringFilterKey.PROVIDER]: ModelCatalogProvider;
237238
[ModelCatalogStringFilterKey.LICENSE]: ModelCatalogLicense;
238239
[ModelCatalogStringFilterKey.LANGUAGE]: AllLanguageCode;
240+
[ModelCatalogStringFilterKey.TENSOR_TYPE]: ModelCatalogTensorType;
239241
[ModelCatalogStringFilterKey.HARDWARE_TYPE]: string;
240242
[ModelCatalogStringFilterKey.HARDWARE_CONFIGURATION]: string;
241243
[ModelCatalogStringFilterKey.USE_CASE]: UseCaseOptionValue;
@@ -301,6 +303,7 @@ export type ModelCatalogFilterStates = {
301303
[ModelCatalogStringFilterKey.PROVIDER]: ModelCatalogProvider[];
302304
[ModelCatalogStringFilterKey.LICENSE]: ModelCatalogLicense[];
303305
[ModelCatalogStringFilterKey.LANGUAGE]: AllLanguageCode[];
306+
[ModelCatalogStringFilterKey.TENSOR_TYPE]: ModelCatalogTensorType[];
304307
[ModelCatalogStringFilterKey.HARDWARE_TYPE]: string[];
305308
[ModelCatalogStringFilterKey.HARDWARE_CONFIGURATION]: string[];
306309
[ModelCatalogStringFilterKey.USE_CASE]: UseCaseOptionValue[];

clients/ui/frontend/src/app/pages/modelCatalog/components/ModelCatalogFilters.tsx

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import TaskFilter from './globalFilters/TaskFilter';
77
import ProviderFilter from './globalFilters/ProviderFilter';
88
import LicenseFilter from './globalFilters/LicenseFilter';
99
import LanguageFilter from './globalFilters/LanguageFilter';
10+
import TensorTypeFilter from './globalFilters/TensorTypeFilter';
1011

1112
const ModelCatalogFilters: React.FC = () => {
1213
const { filterOptions, filterOptionsLoaded, filterOptionsLoadError } =
@@ -22,21 +23,18 @@ const ModelCatalogFilters: React.FC = () => {
2223
</Alert>
2324
);
2425
}
26+
27+
const getFilterProps = (filterKey: ModelCatalogStringFilterKey) =>
28+
filters && filterKey in filters ? filters : undefined;
29+
2530
return (
2631
<Stack hasGutter>
2732
<ModelPerformanceViewToggleCard />
28-
<TaskFilter
29-
filters={filters && ModelCatalogStringFilterKey.TASK in filters ? filters : undefined}
30-
/>
31-
<ProviderFilter
32-
filters={filters && ModelCatalogStringFilterKey.PROVIDER in filters ? filters : undefined}
33-
/>
34-
<LicenseFilter
35-
filters={filters && ModelCatalogStringFilterKey.LICENSE in filters ? filters : undefined}
36-
/>
37-
<LanguageFilter
38-
filters={filters && ModelCatalogStringFilterKey.LANGUAGE in filters ? filters : undefined}
39-
/>
33+
<TaskFilter filters={getFilterProps(ModelCatalogStringFilterKey.TASK)} />
34+
<ProviderFilter filters={getFilterProps(ModelCatalogStringFilterKey.PROVIDER)} />
35+
<LicenseFilter filters={getFilterProps(ModelCatalogStringFilterKey.LICENSE)} />
36+
<LanguageFilter filters={getFilterProps(ModelCatalogStringFilterKey.LANGUAGE)} />
37+
<TensorTypeFilter filters={getFilterProps(ModelCatalogStringFilterKey.TENSOR_TYPE)} />
4038
</Stack>
4139
);
4240
};

clients/ui/frontend/src/app/pages/modelCatalog/components/globalFilters/LanguageFilter.tsx

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import * as React from 'react';
2-
import { StackItem } from '@patternfly/react-core';
2+
import { Divider, StackItem } from '@patternfly/react-core';
33
import ModelCatalogStringFilter from '~/app/pages/modelCatalog/components/ModelCatalogStringFilter';
44
import {
55
ModelCatalogStringFilterKey,
@@ -29,14 +29,17 @@ const LanguageFilter: React.FC<LanguageFilterProps> = ({ filters }) => {
2929
}
3030

3131
return (
32-
<StackItem>
33-
<ModelCatalogStringFilter<ModelCatalogStringFilterKey.LANGUAGE>
34-
title="Language"
35-
filterKey={filterKey}
36-
filterToNameMapping={LANGUAGE_NAME_MAPPING}
37-
filters={language}
38-
/>
39-
</StackItem>
32+
<>
33+
<StackItem>
34+
<ModelCatalogStringFilter<ModelCatalogStringFilterKey.LANGUAGE>
35+
title="Language"
36+
filterKey={filterKey}
37+
filterToNameMapping={LANGUAGE_NAME_MAPPING}
38+
filters={language}
39+
/>
40+
</StackItem>
41+
<Divider />
42+
</>
4043
);
4144
};
4245

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import * as React from 'react';
2+
import { StackItem } from '@patternfly/react-core';
3+
import { CatalogFilterOptions, ModelCatalogStringFilterOptions } from '~/app/modelCatalogTypes';
4+
import { ModelCatalogStringFilterKey, ModelCatalogTensorType } from '~/concepts/modelCatalog/const';
5+
import ModelCatalogStringFilter from '~/app/pages/modelCatalog/components/ModelCatalogStringFilter';
6+
7+
const filterKey = ModelCatalogStringFilterKey.TENSOR_TYPE;
8+
9+
type TensorTypeFilterProps = {
10+
filters?: Extract<CatalogFilterOptions, Partial<ModelCatalogStringFilterOptions>>;
11+
};
12+
13+
const TensorTypeFilter: React.FC<TensorTypeFilterProps> = ({ filters }) => {
14+
const tensorType = filters?.[filterKey];
15+
16+
if (!tensorType) {
17+
return null;
18+
}
19+
20+
return (
21+
<StackItem>
22+
<ModelCatalogStringFilter<ModelCatalogStringFilterKey.TENSOR_TYPE>
23+
title="Tensor type"
24+
filterKey={filterKey}
25+
filters={tensorType}
26+
filterToNameMapping={ModelCatalogTensorType}
27+
/>
28+
</StackItem>
29+
);
30+
};
31+
32+
export default TensorTypeFilter;

0 commit comments

Comments
 (0)