Skip to content

Triton Cache Error for MI250 #2146

@zixianwang2022

Description

@zixianwang2022

Bug description

I get the following error when trying to run the 16B DeepSeek-V3 model (`deepseek_v3_16b.toml`) with torchtitan:

[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] Triton compilation failed: triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] def triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0(in_ptr0, out_ptr0, xnumel, r0_numel, XBLOCK : tl.constexpr, R0_BLOCK : tl.constexpr):
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     xnumel = 4096
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     r0_numel = 16384
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     rnumel = r0_numel
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     RBLOCK: tl.constexpr = R0_BLOCK
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     xoffset = tl.program_id(0) * XBLOCK
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     xmask = xindex < xnumel
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     r0_base = tl.arange(0, R0_BLOCK)[None, :]
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     rbase = r0_base
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     x1 = ((xindex // 32) % 32)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     x0 = (xindex % 32)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     x5 = xindex // 32
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     x2 = xindex // 1024
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     _tmp11 = tl.full([XBLOCK, R0_BLOCK], 0, tl.int64)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     x6 = xindex
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     for r0_offset in tl.range(0, r0_numel, R0_BLOCK, num_stages = 2):
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         r0_index = r0_offset + r0_base
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         r0_mask = r0_index < r0_numel
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         roffset = r0_offset
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         rindex = r0_index
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         r0_4 = r0_index // 128
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         r0_3 = (r0_index % 128)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp5 = tl.load(in_ptr0 + (r0_4 + 128*x5), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp6 = tl.load(in_ptr0 + (r0_3 + 128*x0 + 4096*x2), r0_mask & xmask, eviction_policy='evict_last', other=0.0)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp0 = r0_4 + 128*x1
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp1 = r0_3 + 128*x0
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp2 = tmp0 >= tmp1
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp3 = tl.full([1, 1], True, tl.int1)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp4 = tmp3 & tmp2
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp7 = tmp5 == tmp6
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp8 = tmp4 & tmp7
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp9 = tmp8.to(tl.int64)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp10 = tl.broadcast_to(tmp9, [XBLOCK, R0_BLOCK])
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         tmp12 = _tmp11 + tmp10
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]         _tmp11 = tl.where(r0_mask & xmask, tmp12, _tmp11)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     tmp11 = tl.sum(_tmp11, 1)[:, None]
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     tl.store(out_ptr0 + (x6), tmp11, xmask)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] 
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] metadata: {'signature': {'in_ptr0': '*i32', 'out_ptr0': '*i64', 'xnumel': 'i32', 'r0_numel': 'i32', 'XBLOCK': 'constexpr', 'R0_BLOCK': 'constexpr'}, 'device': 6, 'constants': {'XBLOCK': 64, 'R0_BLOCK': 4}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True, 'device_type': 'hip', 'num_warps': 8, 'num_stages': 1, 'debug': False, 'cc': 'gfx90a'}
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] Traceback (most recent call last):
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 808, in _precompile_config
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     binary = triton.compile(*compile_args, **compile_kwargs)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 267, in compile
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     res = CompiledKernel(src, metadata_group, hash)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 423, in __init__
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     self.asm = AsmDict({
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]                        ^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 424, in <dictcomp>
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]                      ^^^^^^^^^^^^^^^^^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/pathlib.py", line 1050, in read_bytes
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     with self.open(mode='rb') as f:
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]          ^^^^^^^^^^^^^^^^^^^^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/pathlib.py", line 1044, in open
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]     return io.open(self, mode, buffering, encoding, errors, newline)
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:E1211 16:33:14.602000 1463070 site-packages/torch/_inductor/runtime/triton_heuristics.py:810] [0/0] FileNotFoundError: [Errno 2] No such file or directory: '/home1/zixianw4/triton_cache_255441/E5DAZZK6P4ZXQ3VKJPNIEWZ7P4HLUH2VX7O2QO7USGHM5XQCLM6A/triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0.hsaco'
[rank6]: Traceback (most recent call last):
[rank6]:   File "<frozen runpy>", line 198, in _run_module_as_main
[rank6]:   File "<frozen runpy>", line 88, in _run_code
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 768, in <module>
[rank6]:     main(Trainer)
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 755, in main
[rank6]:     trainer.train()
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 362, in wrapper
[rank6]:     return f(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 660, in train
[rank6]:     self.train_step(data_iterator)
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 560, in train_step
[rank6]:     loss = self.forward_backward_step(input_dict, labels)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 475, in forward_backward_step
[rank6]:     inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process(
[rank6]:                                                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/train.py", line 461, in post_dataloading_process
[rank6]:     extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
[rank6]:                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/models/deepseek_v3/model/model.py", line 467, in get_attention_masks
[rank6]:     return create_attention_mask(
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/torchtitan/torchtitan/models/attention.py", line 285, in create_attention_mask
[rank6]:     return _compiled_create_block_mask(*args, **kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 940, in compile_wrapper
[rank6]:     raise e.remove_dynamo_frames() from None  # see TORCHDYNAMO_VERBOSE=1
[rank6]:     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1019, in _compile_fx_inner
[rank6]:     raise InductorError(e, currentframe()).with_traceback(
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1003, in _compile_fx_inner
[rank6]:     mb_compiled_graph = fx_codegen_and_compile(
[rank6]:                         ^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1757, in fx_codegen_and_compile
[rank6]:     return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1537, in codegen_and_compile
[rank6]:     compiled_module = graph.compile_to_module()
[rank6]:                       ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2413, in compile_to_module
[rank6]:     return self._compile_to_module()
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2423, in _compile_to_module
[rank6]:     mod = self._compile_to_module_lines(wrapper_code)
[rank6]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/graph.py", line 2498, in _compile_to_module_lines
[rank6]:     mod = PyCodeCache.load_by_key_path(
[rank6]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3674, in load_by_key_path
[rank6]:     mod = _reload_python_module(key, path, set_sys_modules=in_toplevel)
[rank6]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 35, in _reload_python_module
[rank6]:     exec(code, mod.__dict__, mod.__dict__)
[rank6]:   File "/tmp/torchinductor_zixianw4/hn/chn3wnnc3fohn5f7kaxznnwnjhvbhsjkprpugapu7s4ovbc3wxes.py", line 75, in <module>
[rank6]:     triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0 = async_compile.triton('triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0', '''
[rank6]:                                                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 477, in triton
[rank6]:     kernel.precompile(
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 441, in precompile
[rank6]:     self._precompile_worker()
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 463, in _precompile_worker
[rank6]:     compile_results.append(self._precompile_config(c))
[rank6]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 808, in _precompile_config
[rank6]:     binary = triton.compile(*compile_args, **compile_kwargs)
[rank6]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 267, in compile
[rank6]:     res = CompiledKernel(src, metadata_group, hash)
[rank6]:           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 423, in __init__
[rank6]:     self.asm = AsmDict({
[rank6]:                        ^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/site-packages/triton/compiler/compiler.py", line 424, in <dictcomp>
[rank6]:     file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
[rank6]:                      ^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/pathlib.py", line 1050, in read_bytes
[rank6]:     with self.open(mode='rb') as f:
[rank6]:          ^^^^^^^^^^^^^^^^^^^^
[rank6]:   File "/work1/mzhang/zixianw4/miniconda3/envs/titan/lib/python3.11/pathlib.py", line 1044, in open
[rank6]:     return io.open(self, mode, buffering, encoding, errors, newline)
[rank6]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank6]: torch._inductor.exc.InductorError: FileNotFoundError: [Errno 2] No such file or directory: '/home1/zixianw4/triton_cache_255441/E5DAZZK6P4ZXQ3VKJPNIEWZ7P4HLUH2VX7O2QO7USGHM5XQCLM6A/triton_red_fused_arange_bitwise_and_eq_ge_index_permute_sum_view_0.hsaco'

[rank6]: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

Versions

I am using 8 AMD MI250 GPUs on a single node.

The environment was installed with:
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/rocm7.1

Here is my Slurm batch script:

#!/bin/bash

#SBATCH -J torchtitan_multi_node
#SBATCH -o /work1/mzhang/zixianw4/torchtitan/torchtitan-%j.o
#SBATCH -e /work1/mzhang/zixianw4/torchtitan/torchtitan-%j.e
#SBATCH -t 00:10:00
#SBATCH -p mi2508x
#SBATCH -N 1
#SBATCH -n 8
#######SBATCH --exclude=k005-004,k005-005,k003-001,k004-006,k004-003

# Single-node torchtitan launch: torchrun starts 8 ranks of torchtitan.train
# with the config in $CONFIG_FILE (DeepSeek-V3 16B by default). Extra
# command-line arguments given to this script are forwarded to the trainer.

# --- 1. Robust Environment Setup ---
# Source conda.sh to get the 'conda' command
source /work1/mzhang/zixianw4/miniconda3/etc/profile.d/conda.sh

# Activate the environment
conda activate titan

# FIX: Force the PATH to ensure we use the 'titan' environment binaries
# even if 'conda activate' flaked out in the batch shell.
export TITAN_HOME=/work1/mzhang/zixianw4/miniconda3/envs/titan
export PATH=$TITAN_HOME/bin:$PATH


# --- 2. Diagnostic Check (Check Logs if this fails) ---
echo "------------------------------------------------------"
echo "DEBUG: Checking Python Environment"
echo "which python:   $(which python)"
echo "python version: $(python --version)"
python -c "import torch; print(f'Torch Version: {torch.__version__}'); print(f'ROCm/CUDA: {torch.cuda.is_available()}')"
echo "------------------------------------------------------"

export OMP_NUM_THREADS=2
export CUDA_DEVICE_MAX_CONNECTIONS=1


# --- 3. Compiler Cache Location ---
# The failing run used a Triton cache under /home1 (TRITON_CACHE_DIR was
# presumably set by the calling environment; this script did not set it).
# All 8 ranks compile the same Inductor kernels concurrently. One rank read a
# cache entry whose .hsaco file it then could not open (FileNotFoundError).
# A per-job cache on node-local storage keeps ranks from racing on a
# (presumably shared/network) home filesystem.
# NOTE(review): assumes /tmp is node-local on this cluster -- confirm.
export TRITON_CACHE_DIR="/tmp/${USER}/triton_cache_${SLURM_JOB_ID:-manual}"
mkdir -p "${TRITON_CACHE_DIR}"


# --- 4. Network Discovery Logic (Specific to this Cluster) ---
# Find the head node's interface/IP on the 10.0.200.x subnet.
IB_SUBNET="10.0.200."
export MASTER_PORT=29500

# One hostname per array element. The original `nodes_array=($nodes)` copied
# only the first element, because "$nodes" on an array expands to element 0.
mapfile -t nodes < <(scontrol show hostnames "$SLURM_JOB_NODELIST")
head_node=${nodes[0]}

# SSH into the head node to find the interface and IP on the specific IB subnet.
# -F makes the dots literal (not regex wildcards). head -n 1 keeps a single
# matching line, so the IP and the interface name below come from the same
# interface even if several addresses match.
ib_info=$(ssh "$head_node" "ip -4 a | grep -F 'inet ${IB_SUBNET}'" | head -n 1)

if [[ -z "${ib_info}" ]]; then
    echo "ERROR: Could not find an interface on the IB_SUBNET (${IB_SUBNET}) on master node ${head_node}."
    exit 1
fi

# Extract IP (field 2, without the /prefix) and interface name (last field).
head_node_ip=$(awk '{print $2}' <<< "$ib_info" | cut -d/ -f1)
ib_interface=$(awk '{print $NF}' <<< "$ib_info")

echo "Head Node: $head_node"
echo "Head Node IP: $head_node_ip"
echo "Interface: $ib_interface"

# --- 5. Network Configuration ---
export NCCL_SOCKET_IFNAME=${ib_interface}
export LOGLEVEL=INFO
export NCCL_DEBUG=WARN
export PYTHONFAULTHANDLER=1
# export NCCL_IB_DISABLE=0
# export NCCL_P2P_DISABLE=0

# Clean up AWS/EFA/Frontier specific paths (Removed EFA LD_LIBRARY_PATH)
export CUDA_LAUNCH_BLOCKING=0
export NCCL_BUFFSIZE=2097152

# --- 6. Execution ---
# CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"}
CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml"}

# The rendezvous port comes from MASTER_PORT so the two cannot drift apart.
torchrun --nnodes 1 --nproc_per_node 8 \
    --rdzv_id 101 \
    --rdzv_backend c10d \
    --rdzv_endpoint "${head_node_ip}:${MASTER_PORT}" \
    -m torchtitan.train \
    --job.config_file "${CONFIG_FILE}" "$@"

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions