
Commit 39567e4

Update benchmarks.yml
1 parent d67626c commit 39567e4

File tree

1 file changed: +102 −92 lines changed


.github/workflows/benchmarks.yml

Lines changed: 102 additions & 92 deletions
@@ -23,10 +23,15 @@ jobs:
 options: --gpus all --privileged # Might need privileged mode, use with caution

 steps:
-- name: Checkout XLA
+# - name: Checkout XLA
+# uses: actions/checkout@v3
+# with:
+# repository: openxla/xla # Replace with your fork if needed
+# path: xla
+- name: Checkout repository
 uses: actions/checkout@v3
 with:
-repository: openxla/xla # Replace with your fork if needed
+repository: juliagmt-google/xla
 path: xla

 - name: Create results directory
@@ -53,103 +58,108 @@ jobs:
 - name: Build hlo_runner_main
 working-directory: xla
 run: bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main
-
-- name: Create gemm_006f564ad71b327343de5f090e801883.hlo
-working-directory: xla
-run: |
-cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
-HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
-
-%gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-%parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
-%bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
-%parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
-%bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
-%convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-%dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
-}
-
-ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-%multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
-%Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
-ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-}
-EOF
-
-# - name: Wait For Connection
-# uses: google-ml-infra/actions/ci_connection@main
-# with:
-# halt-dispatch-input: ${{ inputs.halt-for-connection }}
-
-- name: Run specific HLO file
-working-directory: xla
-run: |
-./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
-
+
 - name: Wait For Connection
 uses: google-ml-infra/actions/ci_connection@main
 with:
 halt-dispatch-input: ${{ inputs.halt-for-connection }}
-- name: Download parse_xla_logs.py
-working-directory: xla
-run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py

-- name: Parse XLA logs
-working-directory: xla
-run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
+# - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
+# working-directory: xla
+# run: |
+# cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
+# HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
+
+# %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+# %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
+# %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
+# %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
+# %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
+# %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+# %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+# ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
+# }
+
+# ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+# %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
+# %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
+# ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+# }
+# EOF
+
+# # - name: Wait For Connection
+# # uses: google-ml-infra/actions/ci_connection@main
+# # with:
+# # halt-dispatch-input: ${{ inputs.halt-for-connection }}
+
+# - name: Run specific HLO file
+# working-directory: xla
+# run: |
+# ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log

-- name: Upload Results
-uses: actions/upload-artifact@v4
-with:
-name: gpu-xla-benchmarks
-path: xla/results
-# jax-build-and-test:
-# runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
-# container:
-# image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
-
-# env:
-# JAXCI_HERMETIC_PYTHON_VERSION: 3.11
-
-# steps:
-# - name: Checkout JAX Fork
-# uses: actions/checkout@v3
+# - name: Wait For Connection
+# uses: google-ml-infra/actions/ci_connection@main
 # with:
-# repository: 'google-ml-infra/jax-fork'
-# path: jax-fork
+# halt-dispatch-input: ${{ inputs.halt-for-connection }}
+# - name: Download parse_xla_logs.py
+# working-directory: xla
+# run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py

-# - name: Install JAX Dependencies
-# working-directory: jax-fork
-# run: |
-# python -m pip install --upgrade pip
-# pip install pytest
-# pip install absl-py
-# pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
-# pip install google-benchmark
-# - name: Run JAX Multiprocess GPU Test
-# working-directory: jax-fork
-# continue-on-error: true
-# run: python -m pytest tests/multiprocess_gpu_test.py
+# - name: Parse XLA logs
+# working-directory: xla
+# run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
+
+# - name: Upload Results
+# uses: actions/upload-artifact@v4
+# with:
+# name: gpu-xla-benchmarks
+# path: xla/results
+# # jax-build-and-test:
+# # runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
+# # container:
+# # image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
+
+# # env:
+# # JAXCI_HERMETIC_PYTHON_VERSION: 3.11
+
+# # steps:
+# # - name: Checkout JAX Fork
+# # uses: actions/checkout@v3
+# # with:
+# # repository: 'google-ml-infra/jax-fork'
+# # path: jax-fork
+
+# # - name: Install JAX Dependencies
+# # working-directory: jax-fork
+# # run: |
+# # python -m pip install --upgrade pip
+# # pip install pytest
+# # pip install absl-py
+# # pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
+# # pip install google-benchmark
+# # - name: Run JAX Multiprocess GPU Test
+# # working-directory: jax-fork
+# # continue-on-error: true
+# # run: python -m pytest tests/multiprocess_gpu_test.py


-# - name: Run HLO Module Benchmarks withg GPU in xla/tests/fuzz
-# working-directory: xla
-# continue-on-error: true
-# run: |
-# for file in xla/tests/fuzz/*.hlo; do
-# filename=$(basename "$file")
-# # Skip expected failed hlo files.
-# if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then
-# echo "Skipping benchmark on $file"
-# continue
-# fi
-# echo "Running benchmark on $file" &> results/"$filename".log
-# ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning "$file" &> results/"$filename".log
-# done
-
-# - name: Upload Results
-# uses: actions/upload-artifact@v4
-# with:
-# name: gpu-xla-benchmarks
-# path: xla/results
+# # - name: Run HLO Module Benchmarks withg GPU in xla/tests/fuzz
+# # working-directory: xla
+# # continue-on-error: true
+# # run: |
+# # for file in xla/tests/fuzz/*.hlo; do
+# # filename=$(basename "$file")
+# # # Skip expected failed hlo files.
+# # if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then
+# # echo "Skipping benchmark on $file"
+# # continue
+# # fi
+# # echo "Running benchmark on $file" &> results/"$filename".log
+# # ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning "$file" &> results/"$filename".log
+# # done
+
+# # - name: Upload Results
+# # uses: actions/upload-artifact@v4
+# # with:
+# # name: gpu-xla-benchmarks
+# # path: xla/results
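
For orientation, below is a minimal sketch of the job's active (uncommented) steps as they read after this commit, assembled only from the hunks above. Indentation is assumed, it is a fragment rather than a complete workflow, and the "Create results directory" step plus anything outside the shown diff context are omitted because their bodies are not visible here.

    # Fragment of jobs.<job>.steps; surrounding job, runner, and container
    # configuration is not shown in this diff and is therefore left out.
    steps:
      - name: Checkout repository
        uses: actions/checkout@v3
        with:
          repository: juliagmt-google/xla
          path: xla

      - name: Build hlo_runner_main
        working-directory: xla
        run: bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main

      - name: Wait For Connection
        uses: google-ml-infra/actions/ci_connection@main
        with:
          halt-dispatch-input: ${{ inputs.halt-for-connection }}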

0 commit comments
