|
5 | 5 | * This ensures reproducible notebook generation with explicit ML decisions. |
6 | 6 | */ |
7 | 7 |
|
8 | | -import { z } from 'zod' |
| 8 | +import TOML from '@iarna/toml' |
9 | 9 | import { existsSync, readFileSync } from 'fs' |
10 | 10 | import { basename, dirname, join } from 'path' |
11 | | -import TOML from '@iarna/toml' |
12 | | -import type { CommandDefinition } from '../../types/commands' |
13 | | -import { success, error } from '../../lib/output' |
| 11 | +import { z } from 'zod' |
| 12 | +import { error, success } from '../../lib/output' |
14 | 13 | import { createTemplateEngine, PLATFORM_DISPLAY_NAMES } from '../../templates' |
| 14 | +import type { CommandDefinition } from '../../types/commands' |
15 | 15 | import type { PlatformId } from '../../types/platform' |
16 | 16 | import type { TemplateContext } from '../../types/template' |
17 | 17 |
|
@@ -158,8 +158,12 @@ function generateNotebook(config: TrainingConfig, configPath: string): string { |
158 | 158 | # kaggle: |
159 | 159 | # accelerator: gpu |
160 | 160 | # dataSources: |
161 | | -${dataSources.map(s => `# - type: ${s.type} |
162 | | -# name: ${s.name}`).join('\n')} |
| 161 | +${dataSources |
| 162 | + .map( |
| 163 | + (s) => `# - type: ${s.type} |
| 164 | +# name: ${s.name}` |
| 165 | + ) |
| 166 | + .join('\n')} |
163 | 167 | # docker_image: gcr.io/kaggle-gpu-images/python |
164 | 168 | # isGpuEnabled: true |
165 | 169 | # isInternetEnabled: true |
@@ -274,6 +278,7 @@ CONFIG = { |
274 | 278 |
|
275 | 279 | # Checkpoints |
276 | 280 | "save_total_limit": ${config.checkpoints.save_total_limit}, |
| 281 | + "save_only_model": ${config.checkpoints.save_optimizer === false ? 'True' : 'False'}, # True = skip optimizer.pt (saves ~4GB per checkpoint) |
277 | 282 | "load_best_at_end": ${config.checkpoints.load_best_at_end ? 'True' : 'False'}, |
278 | 283 |
|
279 | 284 | # Early stopping |
@@ -307,7 +312,7 @@ for k, v in CONFIG.items(): |
307 | 312 | # %% |
308 | 313 | # Dataset sources (in priority order) |
309 | 314 | DATASET_SOURCES = [ |
310 | | -${dataSources.map(s => ` "${s.path}",`).join('\n')} |
| 315 | +${dataSources.map((s) => ` "${s.path}",`).join('\n')} |
311 | 316 | ] |
312 | 317 |
|
313 | 318 | train_df = None |
@@ -424,6 +429,7 @@ training_args = Seq2SeqTrainingArguments( |
424 | 429 | save_strategy="steps", |
425 | 430 | save_steps=CONFIG["save_steps"], |
426 | 431 | save_total_limit=CONFIG["save_total_limit"], |
| 432 | + save_only_model=CONFIG["save_only_model"], # Skip optimizer.pt to save disk space |
427 | 433 | logging_steps=CONFIG["logging_steps"], |
428 | 434 | load_best_model_at_end=CONFIG["load_best_at_end"], |
429 | 435 | metric_for_best_model=CONFIG["metric_for_best_model"], |
@@ -569,8 +575,8 @@ function generateMetadata(config: TrainingConfig, outputPath: string): Record<st |
569 | 575 | enable_gpu: true, |
570 | 576 | enable_tpu: false, |
571 | 577 | enable_internet: true, |
572 | | - dataset_sources: dataSources.filter(s => s.type === 'dataset').map(s => s.name), |
573 | | - competition_sources: dataSources.filter(s => s.type === 'competition').map(s => s.name), |
| 578 | + dataset_sources: dataSources.filter((s) => s.type === 'dataset').map((s) => s.name), |
| 579 | + competition_sources: dataSources.filter((s) => s.type === 'competition').map((s) => s.name), |
574 | 580 | kernel_sources: [], |
575 | 581 | model_sources: [], |
576 | 582 | } |
@@ -634,10 +640,8 @@ Example config structure: see notebooks/kaggle/training.toml |
634 | 640 | } |
635 | 641 |
|
636 | 642 | // Determine output path |
637 | | - const outputPath = args.output || join( |
638 | | - dirname(args.path), |
639 | | - `${config.meta.name.toLowerCase().replace(/[^a-z0-9]/g, '_')}.py` |
640 | | - ) |
| 643 | + const outputPath = |
| 644 | + args.output || join(dirname(args.path), `${config.meta.name.toLowerCase().replace(/[^a-z0-9]/g, '_')}.py`) |
641 | 645 |
|
642 | 646 | // Generate notebook |
643 | 647 | const notebook = generateNotebook(config, args.path) |
@@ -675,12 +679,15 @@ Example config structure: see notebooks/kaggle/training.toml |
675 | 679 | if (!args.skipPreflight) { |
676 | 680 | // Import preflight dynamically to avoid circular deps |
677 | 681 | const { preflight } = await import('../preflight') |
678 | | - preflightResult = await preflight.run({ |
679 | | - path: outputPath, |
680 | | - platform: config.platform.target, |
681 | | - samples: 2000, |
682 | | - verbose: false, |
683 | | - }, ctx) |
| 682 | + preflightResult = await preflight.run( |
| 683 | + { |
| 684 | + path: outputPath, |
| 685 | + platform: config.platform.target, |
| 686 | + samples: 2000, |
| 687 | + verbose: false, |
| 688 | + }, |
| 689 | + ctx |
| 690 | + ) |
684 | 691 | } |
685 | 692 |
|
686 | 693 | return success({ |
|
0 commit comments