
Commit cb136bc

Merge pull request AI-Hypercomputer#2605 from AI-Hypercomputer:hengtaoguo-format
PiperOrigin-RevId: 828583732
Parents: 2dc5ffc + 24fd031

11 files changed (+1147, -1135 lines)

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ repos:
       args:
         - '-w'
         - '--skip="*.txt,pylintrc,.*,src/MaxText/assets/*"'
-        - '-L ND,nd,sems,TE,ROUGE,rouge,astroid'
+        - '-L ND,nd,sems,TE,ROUGE,rouge,astroid,dout'
         - '.'
       additional_dependencies:
         - tomli

docs/explanations/performance_metrics.md

Lines changed: 1 addition & 1 deletion
@@ -101,4 +101,4 @@ This shows any of step time, tokens/s or MFU can be used to determine how long t
 
 ## Why not hardware flops?
 
-Hardware (e.g., XLA reported) FLOPs do not accurately reflect computation efficiency as they depend on the program / implementation, not just on the model and its inherent computations (higher hardware FLOPs does not necessarily mean less room for improvement). For example, they include remat and potentially auxilliary operations (such as reshaping for dropping moe [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/layers/moe.py#L1544)), which are an implementation detail and not part of the model. In addition, XLA reported FLOPs may not be accurate with pallas kernels. Hardware flops utilization is not (inversely) proportional to step time as opposed to MFU, since hardware flops can change with implementation details like remat policies.
+Hardware (e.g., XLA reported) FLOPs do not accurately reflect computation efficiency as they depend on the program / implementation, not just on the model and its inherent computations (higher hardware FLOPs does not necessarily mean less room for improvement). For example, they include remat and potentially auxiliary operations (such as reshaping for dropping moe [here](https://github.com/AI-Hypercomputer/maxtext/blob/fafdeaa14183a8f5ca7b9f7b7542ce1655237574/src/MaxText/layers/moe.py#L1544)), which are an implementation detail and not part of the model. In addition, XLA reported FLOPs may not be accurate with pallas kernels. Hardware flops utilization is not (inversely) proportional to step time as opposed to MFU, since hardware flops can change with implementation details like remat policies.
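
As a side note on the claim in this hunk that hardware FLOPs utilization does not track step time the way MFU does, here is a minimal numeric sketch. All numbers are invented for illustration (the 459 peak and 764.67 model-TFLOP figures are the ones quoted elsewhere in this commit); it is not measured MaxText data:

# Hypothetical comparison of two remat policies for the same model.
# Model FLOPs depend only on the model; hardware FLOPs also count recompute.
peak_tflops_per_device = 459.0    # TPU v5p peak, as quoted in the docs in this commit
model_tflop_per_device = 764.67   # fixed by the model and batch, not by the program

# Policy A: light remat, Policy B: heavy remat (illustrative numbers only).
policies = {
    "A (light remat)": {"hw_tflop": 800.0, "step_time_s": 6.0},
    "B (heavy remat)": {"hw_tflop": 1000.0, "step_time_s": 6.5},
}

for name, p in policies.items():
  mfu = model_tflop_per_device / p["step_time_s"] / peak_tflops_per_device
  hfu = p["hw_tflop"] / p["step_time_s"] / peak_tflops_per_device
  print(f"{name}: step={p['step_time_s']}s  MFU={mfu:.2%}  HFU={hfu:.2%}")

# Policy B is slower (worse MFU) yet reports a higher hardware-FLOPs
# utilization, which is the point the paragraph above is making.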

docs/guides/understand_logs_and_metrics.md

Lines changed: 2 additions & 2 deletions
@@ -188,7 +188,7 @@ Per train step:
 
 In this example, given `model=deepseek2-16b`, `per_device_batch_size=24`, `max_target_length=2048`, and no gradient accumulation, we have $\text{model tflop per device} \approx 764.67$.
 - 94.54% of the TFLOPs are attributed to learnable weight and 5.46% are attributed to attention.
-- As you will see next, this number is important for calculating performace metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU).
+- As you will see next, this number is important for calculating performance metrics, such as TFLOP/s/device and Model FLOPs Utilization (MFU).
 
 You can find more information about model FLOPs and MFU in the [Performance Metrics](performance-metrics) topic.
 
@@ -231,7 +231,7 @@ $$\text{tflop/s/device} = \frac{\text{model tflop per device}}{\text{measured st
 
 $$\text{MFU} = \frac{\text{tflop/s/device}}{\text{peak hardware tflop/s}}$$
 
-For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explaination with small batch size and sequence length, so the MFU is not optimal.
+For TPU v5p, $\text{peak hardware tflop/s}=459$. Thus, $134.924 / 459 = 29.40$%. Note this is an example for explanation with small batch size and sequence length, so the MFU is not optimal.
 
 **Tokens per second per device (throughput)**
 
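
To tie the two hunks above together, here is a small sketch of the arithmetic as quoted in this file. The measured step time is not shown in this excerpt, so it is back-solved (roughly 5.67 s) from the quoted 134.924 TFLOP/s/device, and the tokens-per-step formula (per_device_batch_size x max_target_length) is an assumption for illustration:

# Numbers quoted in the hunks above; measured_step_time_s is inferred, not from the doc.
model_tflop_per_device = 764.67            # deepseek2-16b, pdb=24, seq=2048
measured_step_time_s = 764.67 / 134.924    # ~5.667 s, back-solved for illustration
peak_hw_tflops_v5p = 459.0

tflops_per_device = model_tflop_per_device / measured_step_time_s   # ~134.924
mfu = tflops_per_device / peak_hw_tflops_v5p                        # ~0.2940

tokens_per_step_per_device = 24 * 2048     # per_device_batch_size * max_target_length
tokens_per_s_per_device = tokens_per_step_per_device / measured_step_time_s

print(f"TFLOP/s/device ~ {tflops_per_device:.3f}, MFU ~ {mfu:.2%}")
print(f"tokens/s/device ~ {tokens_per_s_per_device:.0f}")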

docs/tutorials/grpo_with_pathways.md

Lines changed: 1 addition & 1 deletion
@@ -66,5 +66,5 @@ The overview of the demo script ~/maxtext/src/MaxText/examples/grpo_llama3_1_70b
 
 1. We load a policy model and a reference model. Both are copies of `Llama3.1-70b-Instruct`.
 2. Evaluate the policy model's performance on GSM8K math reasoning benchmark.
-3. Train the policy model using GRPO with potentially different meshes for trainer and rollout dependending on the parameters `TRAINER_DEVICES_FRACTION` and `SAMPLER_DEVICES_FRACTION`. If we set both of these to `1.0`, the entire (same) mesh will be used for both trainer and rollout. If we set say `TRAINER_DEVICES_FRACTION=0.5` and `SAMPLER_DEVICES_FRACTION=0.5`, the first half of the devices will be used for trainer and the second half will be used for rollout
+3. Train the policy model using GRPO with potentially different meshes for trainer and rollout depending on the parameters `TRAINER_DEVICES_FRACTION` and `SAMPLER_DEVICES_FRACTION`. If we set both of these to `1.0`, the entire (same) mesh will be used for both trainer and rollout. If we set say `TRAINER_DEVICES_FRACTION=0.5` and `SAMPLER_DEVICES_FRACTION=0.5`, the first half of the devices will be used for trainer and the second half will be used for rollout
 4. Evaluate the policy model's performance on GSM8K math reasoning benchmark after the post-training with GRPO.
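
For readers of this doc change, here is a minimal sketch of how such a fractional device split could be expressed. The variable names and slicing below are assumptions for illustration, not the actual MaxText/Pathways implementation:

import jax

# Hypothetical fractions mirroring TRAINER_DEVICES_FRACTION / SAMPLER_DEVICES_FRACTION.
trainer_fraction = 0.5
sampler_fraction = 0.5

devices = jax.devices()
n_trainer = int(len(devices) * trainer_fraction)

if trainer_fraction == 1.0 and sampler_fraction == 1.0:
  trainer_devices = sampler_devices = devices   # same mesh for both roles
else:
  trainer_devices = devices[:n_trainer]         # first part of the devices -> trainer
  sampler_devices = devices[n_trainer:]         # remaining devices -> rollout/sampler

print(f"trainer: {len(trainer_devices)} devices, sampler: {len(sampler_devices)} devices")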

src/MaxText/estimator.py

Lines changed: 4 additions & 4 deletions
@@ -70,7 +70,7 @@ def tensor_score(tensor_name: str, config) -> tuple:
 
   The score is used to prioritize which tensors to offload/remat first. Tensors
   with a higher score are rematerialized later. The scoring is based on tensor
-  arithmatic intensity and memory size, with larger tensors getting lower scores
+  arithmetic intensity and memory size, with larger tensors getting lower scores
   (higher priority for remat).
 
   Args:
@@ -188,19 +188,19 @@ def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
     print(f"No OOM at maximum batch size {max_pdb}.")
     return max_pdb
 
-  low, high, ans = min_pdb, max_pdb, min_pdb
+  low, high, result = min_pdb, max_pdb, min_pdb
   while low <= high:
     mid = (low + high) // 2
     if mid < min_pdb:
       low = mid + 1
       continue
 
     if not is_oom(base_argv, policy, mid):
-      ans = mid
+      result = mid
       low = mid + 1
     else:
      high = mid - 1
-  return ans
+  return result
 
 
 def is_oom(base_argv, policy: dict, pdb: int) -> bool:
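
The renamed `result` variable in this hunk is part of the usual "largest value for which a predicate holds" binary search. A self-contained sketch of that pattern with a stand-in predicate (the real code calls `is_oom`, which actually compiles and runs the model):

def largest_passing(lo: int, hi: int, fits) -> int:
  """Return the largest value in [lo, hi] for which fits(value) is True.

  Assumes fits is monotonic: once it fails, it fails for all larger values.
  """
  result = lo
  while lo <= hi:
    mid = (lo + hi) // 2
    if fits(mid):
      result = mid    # mid works; remember it and try larger values
      lo = mid + 1
    else:
      hi = mid - 1    # mid is too big; shrink the upper bound
  return result

# Stand-in for is_oom: pretend anything above batch size 40 runs out of memory.
print(largest_passing(1, 64, lambda pdb: pdb <= 40))  # -> 40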

src/MaxText/examples/grpo_llama3_1_70b_demo_pw.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
 # for vLLM we can skip JAX precompilation with this flag, it makes startup faster
 os.environ["SKIP_JAX_PRECOMPILE"] = "1"
 
-# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt runnig from
+# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt running from
 # ~/HOME/maxtext/MaxText/examples
 
 # Get the directory of the current script
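
The comment touched in this hunk describes a path fix-up the script performs; the actual lines are not part of this hunk, so the following is only a sketch of the common idiom that comment describes, not the file's code (the same applies to the identical hunk in the 8b demo script below):

import os
import sys

# Directory of the current script, e.g. ~/HOME/maxtext/MaxText/examples
script_dir = os.path.dirname(os.path.abspath(__file__))

# Two levels up, e.g. ~/HOME/maxtext, so imports resolve from the repo root
maxtext_root = os.path.dirname(os.path.dirname(script_dir))
if maxtext_root not in sys.path:
  sys.path.insert(0, maxtext_root)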

src/MaxText/examples/grpo_llama3_1_8b_demo_pw.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@
 # for vLLM we can skip JAX precompilation with this flag, it makes startup faster
 os.environ["SKIP_JAX_PRECOMPILE"] = "1"
 
-# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt runnig from
+# add the parent directory (two levels up to say ~/HOME/maxtext) to sys.path if currenlt running from
 # ~/HOME/maxtext/MaxText/examples
 
 # Get the directory of the current script
