-
Notifications
You must be signed in to change notification settings - Fork 86
Benchmark: Model benchmark - deterministic training support #731
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,150 @@ | ||||||||||
| # Copyright (c) Microsoft Corporation. | ||||||||||
| # Licensed under the MIT license. | ||||||||||
|
|
||||||||||
| """Unified PyTorch deterministic training example for all supported models. | ||||||||||
|
|
||||||||||
| Deterministic metrics (loss, activation mean) are automatically stored in results | ||||||||||
| when --enable_determinism flag is enabled. | ||||||||||
|
|
||||||||||
| To compare deterministic results between runs, use the `sb result diagnosis` command | ||||||||||
| with a baseline file and comparison rules. See the SuperBench documentation for details. | ||||||||||
|
|
||||||||||
| Example workflow: | ||||||||||
| 1. Run first benchmark (creates outputs/<timestamp>/results-summary.jsonl): | ||||||||||
| python3 examples/benchmarks/pytorch_deterministic_example.py \ | ||||||||||
| --model resnet101 --enable_determinism --deterministic_seed 42 | ||||||||||
|
|
||||||||||
| 2. Generate baseline from results: | ||||||||||
| sb result generate-baseline --data-file outputs/<timestamp>/results-summary.jsonl \ | ||||||||||
| --summary-rule-file summary-rules.yaml --output-dir outputs/<timestamp> | ||||||||||
|
|
||||||||||
| 3. Run second benchmark: | ||||||||||
| python3 examples/benchmarks/pytorch_deterministic_example.py \ | ||||||||||
| --model resnet101 --enable_determinism --deterministic_seed 42 | ||||||||||
|
|
||||||||||
| 4. Compare runs with diagnosis: | ||||||||||
| sb result diagnosis --data-file outputs/<run2-timestamp>/results-summary.jsonl \ | ||||||||||
| --rule-file rules.yaml --baseline-file outputs/<run1-timestamp>/baseline.json | ||||||||||
|
|
||||||||||
| Note: CUBLAS_WORKSPACE_CONFIG is now automatically set by the code when determinism is enabled. | ||||||||||
| """ | ||||||||||
|
|
||||||||||
| import argparse | ||||||||||
| import json | ||||||||||
| import socket | ||||||||||
| from datetime import datetime | ||||||||||
| from pathlib import Path | ||||||||||
| from superbench.benchmarks import BenchmarkRegistry, Framework | ||||||||||
| from superbench.common.utils import logger | ||||||||||
|
|
||||||||||
| MODEL_CHOICES = [ | ||||||||||
| 'bert-large', | ||||||||||
| 'gpt2-small', | ||||||||||
| 'llama2-7b', | ||||||||||
| 'mixtral-8x7b', | ||||||||||
| 'resnet101', | ||||||||||
| 'lstm', | ||||||||||
| ] | ||||||||||
|
|
||||||||||
| DEFAULT_PARAMS = { | ||||||||||
| 'bert-large': | ||||||||||
| '--batch_size 1 --seq_len 64 --num_warmup 1 --num_steps 200 --precision float32 ' | ||||||||||
| '--model_action train --check_frequency 20', | ||||||||||
| 'gpt2-small': | ||||||||||
| '--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 128 --precision float32 ' | ||||||||||
| '--model_action train --check_frequency 20', | ||||||||||
| 'llama2-7b': | ||||||||||
| '--batch_size 1 --num_steps 300 --num_warmup 1 --seq_len 512 --precision float32 --model_action train ' | ||||||||||
| '--check_frequency 20', | ||||||||||
| 'mixtral-8x7b': | ||||||||||
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| '--hidden_size 4096 --num_hidden_layers 32 --num_attention_heads 32 --intermediate_size 14336 ' | ||||||||||
| '--num_key_value_heads 8 --max_position_embeddings 32768 --router_aux_loss_coef 0.02 ' | ||||||||||
| '--check_frequency 20', | ||||||||||
| 'resnet101': | ||||||||||
| '--batch_size 1 --precision float32 --num_warmup 1 --num_steps 120 --sample_count 8192 ' | ||||||||||
| '--pin_memory --model_action train --check_frequency 20', | ||||||||||
| 'lstm': | ||||||||||
| '--batch_size 1 --num_steps 100 --num_warmup 2 --seq_len 64 --precision float32 ' | ||||||||||
| '--model_action train --check_frequency 30', | ||||||||||
| } | ||||||||||
|
|
||||||||||
|
|
||||||||||
| def main(): | ||||||||||
| """Main function for determinism example file.""" | ||||||||||
| parser = argparse.ArgumentParser(description='Unified PyTorch deterministic training example.') | ||||||||||
| parser.add_argument('--model', type=str, choices=MODEL_CHOICES, required=True, help='Model to run.') | ||||||||||
| parser.add_argument( | ||||||||||
| '--enable_determinism', | ||||||||||
| action='store_true', | ||||||||||
| help='Enable deterministic mode for reproducible results.', | ||||||||||
| ) | ||||||||||
| parser.add_argument( | ||||||||||
| '--deterministic_seed', | ||||||||||
| type=int, | ||||||||||
| default=None, | ||||||||||
| help='Seed for deterministic training.', | ||||||||||
| ) | ||||||||||
| args = parser.parse_args() | ||||||||||
|
|
||||||||||
| parameters = DEFAULT_PARAMS[args.model] | ||||||||||
| if args.enable_determinism: | ||||||||||
| parameters += ' --enable_determinism' | ||||||||||
| if args.deterministic_seed is not None: | ||||||||||
| parameters += f' --deterministic_seed {args.deterministic_seed}' | ||||||||||
|
|
||||||||||
| context = BenchmarkRegistry.create_benchmark_context(args.model, parameters=parameters, framework=Framework.PYTORCH) | ||||||||||
| benchmark = BenchmarkRegistry.launch_benchmark(context) | ||||||||||
| logger.info(f'Benchmark finished. Return code: {benchmark.return_code}') | ||||||||||
|
|
||||||||||
| # Create timestamped output directory | ||||||||||
| timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') | ||||||||||
| output_dir = Path('outputs') / timestamp | ||||||||||
| output_dir.mkdir(parents=True, exist_ok=True) | ||||||||||
|
|
||||||||||
| # Parse benchmark results | ||||||||||
| benchmark_results = json.loads(benchmark.serialized_result) | ||||||||||
| benchmark_name = benchmark_results.get('name', f'pytorch-{args.model}') | ||||||||||
|
|
||||||||||
| # Convert to results-summary.jsonl format (flattened keys) | ||||||||||
| # Use format compatible with sb result commands: model-benchmarks:<category>/<benchmark>/<metric> | ||||||||||
| summary = {} | ||||||||||
| prefix = f'model-benchmarks:example:determinism/{benchmark_name}' | ||||||||||
| if 'result' in benchmark_results: | ||||||||||
| for metric, values in benchmark_results['result'].items(): | ||||||||||
| # Use first value if it's a list | ||||||||||
| val = values[0] if isinstance(values, list) else values | ||||||||||
| # Add _rank0 suffix to deterministic metrics for compatibility with rules | ||||||||||
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||
| if metric.startswith('deterministic_'): | ||||||||||
|
Comment on lines
+116
to
+117
|
||||||||||
| # Add _rank0 suffix to deterministic metrics for compatibility with rules | |
| if metric.startswith('deterministic_'): | |
| # Add _rank0 suffix to deterministic metrics that don't already have a rank suffix | |
| if metric.startswith('deterministic_') and '_rank' not in metric: |
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -150,6 +150,33 @@ def generate_baseline(self, algo, aggregated_df, diagnosis_rule_file, baseline): | |
| aggregated_df[metrics[index]] = out[1] | ||
| return baseline | ||
|
|
||
| def _format_metric_value(self, metric, val, digit): | ||
| """Format a single baseline metric value based on its type. | ||
|
|
||
| Args: | ||
| metric (str): the metric name. | ||
| val: the metric value. | ||
| digit (int): the number of digits after the decimal point. | ||
|
|
||
| Returns: | ||
| The formatted metric value. | ||
| """ | ||
| if metric not in self._raw_data_df: | ||
| return val | ||
| sample = self._raw_data_df[metric].iloc[0] | ||
| if isinstance(sample, float): | ||
| # Keep full precision for deterministic metrics to avoid false positives in diagnosis | ||
| if 'deterministic' in metric: | ||
| return float(val) | ||
|
Comment on lines
+166
to
+170
|
||
| return f'%.{digit}g' % val if abs(val) < 1 else f'%.{digit}f' % val | ||
| if isinstance(sample, int): | ||
| return int(val) | ||
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| try: | ||
| return float(val) | ||
| except Exception as e: | ||
| logger.error('Analyzer: {} baseline is not numeric, msg: {}'.format(metric, str(e))) | ||
| return val | ||
|
|
||
| def run( | ||
| self, raw_data_file, summary_rule_file, diagnosis_rule_file, pre_baseline_file, algorithm, output_dir, digit=2 | ||
| ): | ||
|
|
@@ -174,19 +201,9 @@ def run( | |
| # generate baseline according to rules in diagnosis and fix threshold outlier detection method | ||
| baseline = self.generate_baseline(algorithm, self._raw_data_df, diagnosis_rule_file, baseline) | ||
| for metric in baseline: | ||
| val = baseline[metric] | ||
| if metric in self._raw_data_df: | ||
| if isinstance(self._raw_data_df[metric].iloc[0], float): | ||
| baseline[metric] = f'%.{digit}g' % val if abs(val) < 1 else f'%.{digit}f' % val | ||
| elif isinstance(self._raw_data_df[metric].iloc[0], int): | ||
| baseline[metric] = int(val) | ||
| else: | ||
| try: | ||
| baseline[metric] = float(val) | ||
| except Exception as e: | ||
| logger.error('Analyzer: {} baseline is not numeric, msg: {}'.format(metric, str(e))) | ||
| baseline[metric] = self._format_metric_value(metric, baseline[metric], digit) | ||
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| baseline = json.dumps(baseline, indent=2, sort_keys=True) | ||
| baseline = re.sub(r': \"(\d+.?\d*)\"', r': \1', baseline) | ||
| baseline = re.sub(r': \"(-?\d+\.?\d*)\"', r': \1', baseline) | ||
Aishwarya-Tonpe marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| with (Path(output_dir) / 'baseline.json').open('w') as f: | ||
| f.write(baseline) | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.