Skip to content

Commit 1b8351c

Browse files
authored
Add logic to extract TPU requests and add to container spec (#40)
1 parent 33591fb commit 1b8351c

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

packages/k8s/src/hooks/prepare-job.ts

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ import {
3636
getNumberOfHost,
3737
sleep,
3838
generateServicesName,
39-
getEntryPointAndArgs
39+
getEntryPointAndArgs,
40+
getTpuRequest
4041
} from '../k8s/utils'
4142
import {
4243
CONTAINER_EXTENSION_PREFIX,
@@ -308,6 +309,7 @@ export function createContainerSpec(
308309
}
309310
}
310311

312+
let tpuRequest = 0
311313
if (!jobContainer && createOptions && createOptions?.length > 0) {
312314
core.debug(
313315
`overriding service container ${JSON.stringify(
@@ -320,6 +322,8 @@ export function createContainerSpec(
320322
container.entryPoint = entryPointAndArgs[0]
321323
container.entryPointArgs = entryPointAndArgs.slice(1)
322324
}
325+
tpuRequest = getTpuRequest(createOptions)
326+
core.debug(`TPU request from service container is ${tpuRequest}`)
323327
}
324328

325329
const podContainer = {
@@ -339,6 +343,18 @@ export function createContainerSpec(
339343
podContainer.args = fixArgs(container.entryPointArgs)
340344
}
341345

346+
if (tpuRequest > 0) {
347+
core.debug(`assigning ${tpuRequest} to podContainer`)
348+
podContainer.resources = {
349+
limits: {
350+
'google.com/tpu': String(tpuRequest)
351+
},
352+
requests: {
353+
'google.com/tpu': String(tpuRequest)
354+
}
355+
}
356+
}
357+
342358
podContainer.env = []
343359
for (const [key, value] of Object.entries(
344360
container['environmentVariables']

packages/k8s/src/k8s/utils.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,14 +467,25 @@ export function createScriptExecutorContainer(
467467
}
468468

469469
const entrypointRegex = /--entrypoint=\[(.*?)\]/
470+
const tpuRegex = /--tpu=([0-9]+)/
470471

471472
export function getEntryPointAndArgs(createOptions: string): string[] {
472473
const match = createOptions.match(entrypointRegex)
473474

474475
if (!match || match[1] === undefined) {
475-
core.debug(`no match for createoptions ${createOptions}`)
476+
core.debug(`no entrypoint found for createOptions ${createOptions}`)
476477
return []
477478
}
478479

479480
return match[1].split(',').map(item => item.trim())
480481
}
482+
483+
export function getTpuRequest(createOptions: string): number {
484+
const match = createOptions.match(tpuRegex)
485+
if (!match || match[1] === undefined) {
486+
core.debug(`no tpu found for ${createOptions}`)
487+
return 0
488+
}
489+
490+
return Number(match[1])
491+
}

packages/k8s/tests/k8s-utils-test.ts

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ import {
1313
getNumberOfHost,
1414
ENV_NUMBER_OF_HOSTS,
1515
generateServicesName,
16-
getEntryPointAndArgs
16+
getEntryPointAndArgs,
17+
getTpuRequest
1718
} from '../src/k8s/utils'
1819
import * as k8s from '@kubernetes/client-node'
1920
import { TestHelper } from './test-setup'
@@ -389,6 +390,13 @@ describe('k8s utils', () => {
389390
})
390391
})
391392

393+
describe('getTpuRequest', () => {
394+
it('should return correct number of TPU requested', () => {
395+
expect(getTpuRequest('--tpu=4')).toEqual(4)
396+
expect(getTpuRequest('--nothing')).toEqual(0)
397+
})
398+
})
399+
392400
describe('create script executor container', () => {
393401
it('should install script executor at the volume mount location', () => {
394402
const executorVolumeMount = new k8s.V1VolumeMount()

0 commit comments

Comments
 (0)