Skip to content

Commit 9355642

Browse files
authored
Fetch numNodes from ClusterTrainingRuntime (opendatahub-io#5508)
* Fetch numNodes from ClusterTrainingRuntime * resolving type check
1 parent 7e4df91 commit 9355642

10 files changed

Lines changed: 383 additions & 12 deletions

File tree

frontend/src/api/models/kubeflow.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ export const TrainJobModel: K8sModelCommon = {
1313
kind: 'TrainJob',
1414
plural: 'trainjobs',
1515
};
16+
17+
export const ClusterTrainingRuntimeModel: K8sModelCommon = {
18+
apiVersion: 'v1alpha1',
19+
apiGroup: 'trainer.kubeflow.org',
20+
kind: 'ClusterTrainingRuntime',
21+
plural: 'clustertrainingruntimes',
22+
};
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import { ClusterTrainingRuntimeKind } from '@odh-dashboard/model-training/k8sTypes';
2+
3+
type MockClusterTrainingRuntimeConfigType = {
4+
name?: string;
5+
numNodes?: number;
6+
numProcPerNode?: string | number;
7+
framework?: string;
8+
};
9+
10+
export const mockClusterTrainingRuntimeK8sResource = ({
11+
name = 'training-cuda128-torch28-py312',
12+
numNodes = 1,
13+
numProcPerNode = 'auto',
14+
framework = 'torch',
15+
}: MockClusterTrainingRuntimeConfigType = {}): ClusterTrainingRuntimeKind => ({
16+
apiVersion: 'trainer.kubeflow.org/v1alpha1',
17+
kind: 'ClusterTrainingRuntime',
18+
metadata: {
19+
name,
20+
creationTimestamp: '2025-11-05T14:36:39Z',
21+
labels: {
22+
'app.kubernetes.io/component': 'controller',
23+
'app.kubernetes.io/name': 'trainer',
24+
[`trainer.kubeflow.org/framework`]: framework,
25+
},
26+
resourceVersion: '277649620',
27+
generation: 2,
28+
},
29+
spec: {
30+
mlPolicy: {
31+
numNodes,
32+
torch: {
33+
numProcPerNode,
34+
},
35+
},
36+
template: {
37+
spec: {
38+
replicatedJobs: [
39+
{
40+
groupName: 'default',
41+
name: 'node',
42+
replicas: 1,
43+
template: {
44+
metadata: {
45+
labels: {
46+
'trainer.kubeflow.org/trainjob-ancestor-step': 'trainer',
47+
},
48+
},
49+
spec: {
50+
template: {
51+
metadata: {},
52+
spec: {
53+
containers: [
54+
{
55+
image: 'quay.io/rhoai/odh-training-cuda128-torch28-py312-rhel9:rhoai-3.0',
56+
name: 'node',
57+
resources: {},
58+
},
59+
],
60+
},
61+
},
62+
},
63+
},
64+
},
65+
],
66+
},
67+
},
68+
},
69+
});

packages/model-training/src/api.ts

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import {
22
k8sDeleteResource,
33
K8sStatus,
44
k8sPatchResource,
5+
k8sGetResource,
56
} from '@openshift/dynamic-plugin-sdk-utils';
67
import { applyK8sAPIOptions } from '@odh-dashboard/internal/api/apiMergeUtils';
78
import { K8sAPIOptions, WorkloadKind } from '@odh-dashboard/internal/k8sTypes';
@@ -10,8 +11,11 @@ import { WorkloadModel } from '@odh-dashboard/internal/api/models/kueue';
1011
import { groupVersionKind } from '@odh-dashboard/internal/api/k8sUtils';
1112
import { CustomWatchK8sResult } from '@odh-dashboard/internal/types';
1213
import useK8sWatchResourceList from '@odh-dashboard/internal/utilities/useK8sWatchResourceList';
13-
import { TrainJobModel } from '@odh-dashboard/internal/api/models/kubeflow';
14-
import { TrainJobKind } from './k8sTypes';
14+
import {
15+
TrainJobModel,
16+
ClusterTrainingRuntimeModel,
17+
} from '@odh-dashboard/internal/api/models/kubeflow';
18+
import { TrainJobKind, ClusterTrainingRuntimeKind } from './k8sTypes';
1519

1620
export const useTrainJobs = (namespace: string): CustomWatchK8sResult<TrainJobKind[]> =>
1721
useK8sWatchResourceList(
@@ -166,3 +170,17 @@ export const toggleTrainJobHibernation = async (
166170
};
167171
}
168172
};
173+
174+
export const getClusterTrainingRuntime = (
175+
name: string,
176+
opts?: K8sAPIOptions,
177+
): Promise<ClusterTrainingRuntimeKind> =>
178+
k8sGetResource<ClusterTrainingRuntimeKind>(
179+
applyK8sAPIOptions(
180+
{
181+
model: ClusterTrainingRuntimeModel,
182+
queryOptions: { name },
183+
},
184+
opts,
185+
),
186+
);

packages/model-training/src/global/trainingJobDetailsDrawer/TrainingJobResourcesTab.tsx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ const TrainingJobResourcesTab: React.FC<TrainingJobResourcesTabProps> = ({ job }
5252
isDisabled // TODO: RHOAIENG-37576 Uncomment this when scaling is implemented
5353
data-testid="nodes-edit-button"
5454
>
55-
{job.spec.trainer.numNodes}
55+
{job.spec.trainer?.numNodes}
5656
</Button>
5757
</DescriptionListDescription>
5858
</DescriptionListGroup>
@@ -61,7 +61,7 @@ const TrainingJobResourcesTab: React.FC<TrainingJobResourcesTabProps> = ({ job }
6161
Processes per node:
6262
</DescriptionListTerm>
6363
<DescriptionListDescription data-testid="processes-per-node-value">
64-
{job.spec.trainer.numProcPerNode || '-'}
64+
{job.spec.trainer?.numProcPerNode || '-'}
6565
</DescriptionListDescription>
6666
</DescriptionListGroup>
6767
</DescriptionList>
@@ -76,29 +76,29 @@ const TrainingJobResourcesTab: React.FC<TrainingJobResourcesTabProps> = ({ job }
7676
CPU requests:
7777
</DescriptionListTerm>
7878
<DescriptionListDescription data-testid="cpu-requests-value">
79-
{job.spec.trainer.resourcesPerNode?.requests?.cpu || '-'}
79+
{job.spec.trainer?.resourcesPerNode?.requests?.cpu || '-'}
8080
</DescriptionListDescription>
8181
</DescriptionListGroup>
8282
<DescriptionListGroup>
8383
<DescriptionListTerm style={{ fontWeight: 'normal' }}>CPU limits:</DescriptionListTerm>
8484
<DescriptionListDescription data-testid="cpu-limits-value">
85-
{job.spec.trainer.resourcesPerNode?.limits?.cpu || '-'}
85+
{job.spec.trainer?.resourcesPerNode?.limits?.cpu || '-'}
8686
</DescriptionListDescription>
8787
</DescriptionListGroup>
8888
<DescriptionListGroup>
8989
<DescriptionListTerm style={{ fontWeight: 'normal' }}>
9090
Memory requests:
9191
</DescriptionListTerm>
9292
<DescriptionListDescription data-testid="memory-requests-value">
93-
{job.spec.trainer.resourcesPerNode?.requests?.memory || '-'}
93+
{job.spec.trainer?.resourcesPerNode?.requests?.memory || '-'}
9494
</DescriptionListDescription>
9595
</DescriptionListGroup>
9696
<DescriptionListGroup>
9797
<DescriptionListTerm style={{ fontWeight: 'normal' }}>
9898
Memory limits:
9999
</DescriptionListTerm>
100100
<DescriptionListDescription data-testid="memory-limits-value">
101-
{job.spec.trainer.resourcesPerNode?.limits?.memory || '-'}
101+
{job.spec.trainer?.resourcesPerNode?.limits?.memory || '-'}
102102
</DescriptionListDescription>
103103
</DescriptionListGroup>
104104
</DescriptionList>

packages/model-training/src/global/trainingJobList/TrainingJobClusterQueue.tsx

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ const TrainingJobClusterQueue: React.FC<TrainingJobClusterQueueProps> = ({
1616
namespace,
1717
);
1818

19+
if (!localQueueName) {
20+
return '-';
21+
}
22+
1923
if (!clusterQueueLoaded) {
2024
return <Skeleton width="100px" />;
2125
}

packages/model-training/src/global/trainingJobList/TrainingJobTableRow.tsx

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import StateActionToggle from './StateActionToggle';
1414
import { TrainJobKind } from '../../k8sTypes';
1515
import { TrainingJobState } from '../../types';
1616
import { toggleTrainJobHibernation } from '../../api';
17+
import useClusterTrainingRuntime from '../../hooks/useClusterTrainingRuntime';
1718

1819
type TrainingJobTableRowProps = {
1920
job: TrainJobKind;
@@ -35,7 +36,25 @@ const TrainingJobTableRow: React.FC<TrainingJobTableRowProps> = ({
3536
const [isToggling, setIsToggling] = React.useState(false);
3637

3738
const displayName = getDisplayNameFromK8sResource(job);
38-
const nodesCount = job.spec.trainer.numNodes || 0;
39+
40+
// Fetch ClusterTrainingRuntime if trainer spec is not available
41+
const runtimeName =
42+
job.spec.runtimeRef.kind === 'ClusterTrainingRuntime' ? job.spec.runtimeRef.name : null;
43+
const { clusterTrainingRuntime, loaded: runtimeLoaded } = useClusterTrainingRuntime(
44+
!job.spec.trainer ? runtimeName : null,
45+
);
46+
47+
// Get numNodes from trainer spec or ClusterTrainingRuntime
48+
const nodesCount = React.useMemo(() => {
49+
if (job.spec.trainer?.numNodes) {
50+
return job.spec.trainer.numNodes;
51+
}
52+
if (runtimeLoaded && clusterTrainingRuntime?.spec.mlPolicy?.numNodes) {
53+
return clusterTrainingRuntime.spec.mlPolicy.numNodes;
54+
}
55+
return 0;
56+
}, [job.spec.trainer?.numNodes, runtimeLoaded, clusterTrainingRuntime]);
57+
3958
const localQueueName = job.metadata.labels?.['kueue.x-k8s.io/queue-name'];
4059

4160
const status = jobStatus || getTrainingJobStatusSync(job);

packages/model-training/src/global/trainingJobList/const.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ export const columns: SortableData<TrainJobKind>[] = [
2424
label: 'Nodes',
2525
width: 15,
2626
sortable: (a: TrainJobKind, b: TrainJobKind): number => {
27-
const aNodes = a.spec.trainer.numNodes || 0;
28-
const bNodes = b.spec.trainer.numNodes || 0;
27+
const aNodes = a.spec.trainer?.numNodes || 0;
28+
const bNodes = b.spec.trainer?.numNodes || 0;
2929

3030
return aNodes - bNodes;
3131
},

0 commit comments

Comments
 (0)