Skip to content

Commit 5736452

Browse files
committed
fix: batch-transform deploy/test/clean UX improvements
1 parent 60656d1 commit 5736452

19 files changed

Lines changed: 2890 additions & 6 deletions

generators/app/index.js

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,47 @@ export default class extends Generator {
204204
type: Number,
205205
description: 'Max concurrent invocations per instance for async inference (default: 1)'
206206
});
207+
208+
// Batch transform options
209+
this.option('batch-input-path', {
210+
type: String,
211+
description: 'S3 input path for batch transform data'
212+
});
213+
214+
this.option('batch-output-path', {
215+
type: String,
216+
description: 'S3 output path for batch transform results'
217+
});
218+
219+
this.option('batch-instance-count', {
220+
type: Number,
221+
description: 'Number of instances for batch transform job (default: 1)'
222+
});
223+
224+
this.option('batch-split-type', {
225+
type: String,
226+
description: 'Input data split type: Line, RecordIO, None (default: Line)'
227+
});
228+
229+
this.option('batch-strategy', {
230+
type: String,
231+
description: 'Batch strategy: MultiRecord, SingleRecord (default: MultiRecord)'
232+
});
233+
234+
this.option('batch-join-source', {
235+
type: String,
236+
description: 'Join source: Input, None (default: None)'
237+
});
238+
239+
this.option('batch-max-concurrent', {
240+
type: Number,
241+
description: 'Max concurrent transforms per instance (default: 1)'
242+
});
243+
244+
this.option('batch-max-payload', {
245+
type: Number,
246+
description: 'Max payload size in MB, 0-100 (default: 6)'
247+
});
207248
}
208249

209250
/**
@@ -447,6 +488,12 @@ export default class extends Generator {
447488
architecture = this.answers.framework === 'transformers' ? 'transformers' : 'http';
448489
}
449490

491+
// Exclude sample_model directory when not needed
492+
// Transformers and diffusors don't use sample models (they load from HuggingFace Hub)
493+
if (!this.answers.includeSampleModel || architecture === 'transformers' || architecture === 'diffusors') {
494+
ignorePatterns.push('**/sample_model/**');
495+
}
496+
450497
// Always exclude triton and diffusors source directories from initial copy (they are sources, not output)
451498
ignorePatterns.push('**/triton/**');
452499
ignorePatterns.push('**/diffusors/**');

generators/app/lib/config-manager.js

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,94 @@ export default class ConfigManager {
600600
required: false,
601601
default: 1,
602602
valueSpace: 'bounded'
603+
},
604+
batchInputPath: {
605+
cliOption: 'batch-input-path',
606+
envVar: 'ML_BATCH_INPUT_PATH',
607+
configFile: true,
608+
packageJson: false,
609+
mcp: true,
610+
promptable: true,
611+
required: false,
612+
default: null,
613+
valueSpace: 'unbounded'
614+
},
615+
batchOutputPath: {
616+
cliOption: 'batch-output-path',
617+
envVar: 'ML_BATCH_OUTPUT_PATH',
618+
configFile: true,
619+
packageJson: false,
620+
mcp: true,
621+
promptable: true,
622+
required: false,
623+
default: null,
624+
valueSpace: 'unbounded'
625+
},
626+
batchInstanceCount: {
627+
cliOption: 'batch-instance-count',
628+
envVar: null,
629+
configFile: true,
630+
packageJson: false,
631+
mcp: false,
632+
promptable: true,
633+
required: false,
634+
default: 1,
635+
valueSpace: 'bounded'
636+
},
637+
batchSplitType: {
638+
cliOption: 'batch-split-type',
639+
envVar: null,
640+
configFile: true,
641+
packageJson: false,
642+
mcp: false,
643+
promptable: true,
644+
required: false,
645+
default: 'Line',
646+
valueSpace: 'bounded'
647+
},
648+
batchStrategy: {
649+
cliOption: 'batch-strategy',
650+
envVar: null,
651+
configFile: true,
652+
packageJson: false,
653+
mcp: false,
654+
promptable: true,
655+
required: false,
656+
default: 'MultiRecord',
657+
valueSpace: 'bounded'
658+
},
659+
batchJoinSource: {
660+
cliOption: 'batch-join-source',
661+
envVar: null,
662+
configFile: true,
663+
packageJson: false,
664+
mcp: false,
665+
promptable: true,
666+
required: false,
667+
default: 'None',
668+
valueSpace: 'bounded'
669+
},
670+
batchMaxConcurrentTransforms: {
671+
cliOption: 'batch-max-concurrent',
672+
envVar: null,
673+
configFile: true,
674+
packageJson: false,
675+
mcp: false,
676+
promptable: true,
677+
required: false,
678+
default: 1,
679+
valueSpace: 'bounded'
680+
},
681+
batchMaxPayloadInMB: {
682+
cliOption: 'batch-max-payload',
683+
envVar: null,
684+
configFile: true,
685+
packageJson: false,
686+
mcp: false,
687+
promptable: true,
688+
required: false,
689+
default: 6,
690+
valueSpace: 'bounded'
603691
}
604692
};
605693
}

generators/app/lib/prompt-runner.js

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import {
2222
infraRegionAndTargetPrompts,
2323
infraInstancePrompts,
2424
infraAsyncPrompts,
25+
infraBatchTransformPrompts,
2526
infraHyperPodPrompts,
2627
infraBuildPrompts,
2728
projectPrompts,
@@ -69,10 +70,11 @@ export default class PromptRunner {
6970
await this._queryMcpForRegion({}, explicitConfig);
7071
const regionAndTargetAnswers = await this._runPhase(infraRegionAndTargetPrompts, {}, explicitConfig, existingConfig);
7172

72-
// 1b. Instance type — query MCP and prompt for managed-inference, async-inference, and hyperpod-eks
73+
// 1b. Instance type — query MCP and prompt for managed-inference, async-inference, batch-transform, and hyperpod-eks
7374
let instanceAnswers = {};
7475
if (regionAndTargetAnswers.deploymentTarget === 'managed-inference' ||
7576
regionAndTargetAnswers.deploymentTarget === 'async-inference' ||
77+
regionAndTargetAnswers.deploymentTarget === 'batch-transform' ||
7678
regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
7779
await this._queryMcpForInstance({}, explicitConfig);
7880
const mcpInstanceChoices = this.configManager?.mcpChoices?.instanceType;
@@ -89,6 +91,17 @@ export default class PromptRunner {
8991
asyncAnswers = await this._runPhase(infraAsyncPrompts, { ...regionAndTargetAnswers }, explicitConfig, existingConfig);
9092
}
9193

94+
// 1b-batch. Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
95+
let batchTransformAnswers = {};
96+
if (regionAndTargetAnswers.deploymentTarget === 'batch-transform') {
97+
batchTransformAnswers = await this._runPhase(
98+
infraBatchTransformPrompts,
99+
{ ...regionAndTargetAnswers },
100+
explicitConfig,
101+
existingConfig
102+
);
103+
}
104+
92105
// 1c. HyperPod prompts — only query MCP and prompt when deployment target is hyperpod-eks
93106
let hyperPodAnswers = {};
94107
if (regionAndTargetAnswers.deploymentTarget === 'hyperpod-eks') {
@@ -106,6 +119,7 @@ export default class PromptRunner {
106119
...regionAndTargetAnswers,
107120
...instanceAnswers,
108121
...asyncAnswers,
122+
...batchTransformAnswers,
109123
...hyperPodAnswers,
110124
...buildAnswers
111125
};

generators/app/lib/prompts.js

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ const infraRegionAndTargetPrompts = [
611611
choices: [
612612
{ name: 'SageMaker Managed Inference - Real Time', value: 'managed-inference' },
613613
{ name: 'SageMaker Managed Inference - Async', value: 'async-inference' },
614+
{ name: 'SageMaker Managed Inference - Batch', value: 'batch-transform' },
614615
{ name: 'SageMaker HyperPod - EKS', value: 'hyperpod-eks' }
615616
],
616617
default: 'managed-inference'
@@ -622,7 +623,7 @@ const infraInstancePrompts = [
622623
{
623624
type: 'list',
624625
name: 'instanceType',
625-
when: answers => answers.deploymentTarget === 'managed-inference' || answers.deploymentTarget === 'async-inference' || answers.deploymentTarget === 'hyperpod-eks',
626+
when: answers => answers.deploymentTarget === 'managed-inference' || answers.deploymentTarget === 'async-inference' || answers.deploymentTarget === 'batch-transform' || answers.deploymentTarget === 'hyperpod-eks',
626627
message: (answers) => {
627628
const framework = answers.framework || answers.deploymentConfig?.split('-')[0];
628629

@@ -854,6 +855,80 @@ const infraAsyncPrompts = [
854855
}
855856
];
856857

858+
// Shared `when` predicate: these questions only apply to the batch-transform target.
const isBatchTransformTarget = answers => answers.deploymentTarget === 'batch-transform';

// Strips surrounding whitespace from free-text answers; non-string values pass through untouched.
const trimBatchAnswer = input => (typeof input === 'string' ? input.trim() : input);

// Accepts a blank answer or an "s3://" URI; otherwise returns the message to show the user.
const validateBatchS3Path = input => {
    const value = trimBatchAnswer(input ?? '');
    if (value === '') {
        return true;
    }
    return (typeof value === 'string' && value.startsWith('s3://'))
        || 'Path must start with "s3://". Example: s3://my-bucket/path/';
};

// Builds a validator that requires an integer in [min, max] and returns `message` otherwise.
const validateBatchInteger = (min, max, message) => input =>
    (Number.isInteger(input) && input >= min && input <= max) || message;

/**
 * Sub-phase: Batch transform-specific prompts (only when deploymentTarget === 'batch-transform')
 * Requirements: 2.1, 2.2, 2.4, 2.5, 2.6, 2.7, 2.8, 2.9
 *
 * Each question is gated on the batch-transform deployment target. The path and
 * number questions carry `validate` callbacks (returning `true` or an error
 * message string) so that out-of-range values are rejected at the prompt rather
 * than surfacing later as a thrown validation error. The limits used here are:
 *   - S3 paths: blank, or starting with "s3://"
 *   - batchInstanceCount: integer >= 1
 *   - batchMaxConcurrentTransforms: integer >= 0
 *   - batchMaxPayloadInMB: integer between 0 and 100
 */
const infraBatchTransformPrompts = [
    {
        type: 'input',
        name: 'batchInputPath',
        message: 'S3 input path for batch transform data (leave empty for default: s3://ml-container-creator-batch-{region}-{account-id}/{project-name}/input/):',
        // Trim so that a padded but otherwise valid URI is stored in canonical form
        filter: trimBatchAnswer,
        validate: validateBatchS3Path,
        when: isBatchTransformTarget
    },
    {
        type: 'input',
        name: 'batchOutputPath',
        message: 'S3 output path for batch transform results (leave empty for default: s3://ml-container-creator-batch-{region}-{account-id}/{project-name}/output/):',
        filter: trimBatchAnswer,
        validate: validateBatchS3Path,
        when: isBatchTransformTarget
    },
    {
        type: 'number',
        name: 'batchInstanceCount',
        message: 'How many instances should run the batch job in parallel?',
        default: 1,
        validate: validateBatchInteger(1, Number.POSITIVE_INFINITY, 'Instance count must be an integer >= 1'),
        when: isBatchTransformTarget
    },
    {
        type: 'list',
        name: 'batchSplitType',
        message: 'Input file format — how should SageMaker read your input files?',
        choices: [
            { name: 'Line — one record per line (JSON lines, CSV)', value: 'Line' },
            { name: 'RecordIO — Amazon RecordIO format', value: 'RecordIO' },
            { name: 'None — send each file as a single request', value: 'None' }
        ],
        default: 'Line',
        when: isBatchTransformTarget
    },
    {
        type: 'list',
        name: 'batchStrategy',
        message: 'How many records should be sent per inference request?',
        choices: [
            { name: 'MultiRecord — batch multiple records per request (higher throughput)', value: 'MultiRecord' },
            { name: 'SingleRecord — one record per request (simpler, more predictable)', value: 'SingleRecord' }
        ],
        default: 'MultiRecord',
        when: isBatchTransformTarget
    },
    {
        type: 'list',
        name: 'batchJoinSource',
        message: 'Include original input data alongside predictions in the output?',
        choices: [
            { name: 'No — output predictions only', value: 'None' },
            { name: 'Yes — merge input with predictions (useful for traceability)', value: 'Input' }
        ],
        default: 'None',
        when: isBatchTransformTarget
    },
    {
        type: 'number',
        name: 'batchMaxConcurrentTransforms',
        message: 'Max concurrent inference requests per instance?',
        default: 1,
        validate: validateBatchInteger(0, Number.POSITIVE_INFINITY, 'Max concurrent transforms must be an integer >= 0'),
        when: isBatchTransformTarget
    },
    {
        type: 'number',
        name: 'batchMaxPayloadInMB',
        message: 'Max request payload size in MB (0-100)?',
        default: 6,
        validate: validateBatchInteger(0, 100, 'Max payload must be an integer between 0 and 100'),
        when: isBatchTransformTarget
    }
];
931+
857932
// Combined view for tests and backward compatibility
858933
const infrastructurePrompts = [
859934
...infraRegionAndTargetPrompts,
@@ -977,6 +1052,7 @@ export {
9771052
infraRegionAndTargetPrompts,
9781053
infraInstancePrompts,
9791054
infraAsyncPrompts,
1055+
infraBatchTransformPrompts,
9801056
infraHyperPodPrompts,
9811057
infraBuildPrompts,
9821058
projectPrompts,

generators/app/lib/template-manager.js

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ export default class TemplateManager {
6464
'diffusors-vllm-omni'
6565
],
6666
buildTargets: ['codebuild'],
67-
deploymentTargets: ['managed-inference', 'async-inference', 'hyperpod-eks'],
67+
deploymentTargets: ['managed-inference', 'async-inference', 'batch-transform', 'hyperpod-eks'],
6868
testTypes: ['local-model-cli', 'local-model-server', 'hosted-model-endpoint'],
6969
awsRegions: [
7070
'us-east-1', 'us-east-2', 'us-west-1', 'us-west-2',
@@ -131,6 +131,9 @@ export default class TemplateManager {
131131

132132
// Validate async inference specific fields
133133
this._validateAsyncConfig()
134+
135+
// Validate batch transform specific fields
136+
this._validateBatchTransformConfig()
134137

135138
// Validate instance type format (ml.*.*) - only for managed-inference
136139
if (this.answers.instanceType && this.answers.instanceType !== 'custom') {
@@ -229,6 +232,71 @@ export default class TemplateManager {
229232
}
230233
}
231234

235+
/**
236+
* Validates batch transform specific configuration
237+
* @private
238+
* @throws {Error} If batch transform configuration is invalid
239+
*/
240+
_validateBatchTransformConfig() {
241+
if (this.answers.deploymentTarget !== 'batch-transform') return
242+
243+
// Validate S3 input path format if provided
244+
if (this.answers.batchInputPath && this.answers.batchInputPath.trim() !== '') {
245+
if (!this.answers.batchInputPath.startsWith('s3://')) {
246+
throw new Error('⚠️ batchInputPath must start with "s3://". Example: s3://my-bucket/input/')
247+
}
248+
}
249+
250+
// Validate S3 output path format if provided
251+
if (this.answers.batchOutputPath && this.answers.batchOutputPath.trim() !== '') {
252+
if (!this.answers.batchOutputPath.startsWith('s3://')) {
253+
throw new Error('⚠️ batchOutputPath must start with "s3://". Example: s3://my-bucket/output/')
254+
}
255+
}
256+
257+
// Validate instance count
258+
if (this.answers.batchInstanceCount !== undefined) {
259+
const val = this.answers.batchInstanceCount
260+
if (!Number.isInteger(val) || val < 1) {
261+
throw new Error('⚠️ batchInstanceCount must be an integer >= 1')
262+
}
263+
}
264+
265+
// Validate split type
266+
const validSplitTypes = ['Line', 'RecordIO', 'None']
267+
if (this.answers.batchSplitType && !validSplitTypes.includes(this.answers.batchSplitType)) {
268+
throw new Error(`⚠️ batchSplitType must be one of: ${validSplitTypes.join(', ')}`)
269+
}
270+
271+
// Validate batch strategy
272+
const validStrategies = ['MultiRecord', 'SingleRecord']
273+
if (this.answers.batchStrategy && !validStrategies.includes(this.answers.batchStrategy)) {
274+
throw new Error(`⚠️ batchStrategy must be one of: ${validStrategies.join(', ')}`)
275+
}
276+
277+
// Validate join source
278+
const validJoinSources = ['Input', 'None']
279+
if (this.answers.batchJoinSource && !validJoinSources.includes(this.answers.batchJoinSource)) {
280+
throw new Error(`⚠️ batchJoinSource must be one of: ${validJoinSources.join(', ')}`)
281+
}
282+
283+
// Validate max concurrent transforms
284+
if (this.answers.batchMaxConcurrentTransforms !== undefined) {
285+
const val = this.answers.batchMaxConcurrentTransforms
286+
if (!Number.isInteger(val) || val < 0) {
287+
throw new Error('⚠️ batchMaxConcurrentTransforms must be an integer >= 0')
288+
}
289+
}
290+
291+
// Validate max payload in MB
292+
if (this.answers.batchMaxPayloadInMB !== undefined) {
293+
const val = this.answers.batchMaxPayloadInMB
294+
if (!Number.isInteger(val) || val < 0 || val > 100) {
295+
throw new Error('⚠️ batchMaxPayloadInMB must be an integer between 0 and 100')
296+
}
297+
}
298+
}
299+
232300
/**
233301
* Validates GPU instance type requirement for GPU-requiring backends.
234302
* Called when deploymentConfig is present.

0 commit comments

Comments
 (0)