Skip to content

Commit a94f159

Browse files
Update benchmarks.yml
1 parent 39567e4 commit a94f159

File tree

1 file changed

+62
-57
lines changed

1 file changed

+62
-57
lines changed

.github/workflows/benchmarks.yml

Lines changed: 62 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,16 @@ jobs:
2323
options: --gpus all --privileged # Might need privileged mode, use with caution
2424

2525
steps:
26-
# - name: Checkout XLA
27-
# uses: actions/checkout@v3
28-
# with:
29-
# repository: openxla/xla # Replace with your fork if needed
30-
# path: xla
31-
- name: Checkout repository
26+
- name: Checkout XLA
3227
uses: actions/checkout@v3
3328
with:
34-
repository: juliagmt-google/xla
29+
repository: openxla/xla # Replace with your fork if needed
3530
path: xla
31+
# - name: Checkout repository
32+
# uses: actions/checkout@v3
33+
# with:
34+
# repository: juliagmt-google/xla
35+
# path: xla
3636

3737
- name: Create results directory
3838
working-directory: xla
@@ -59,61 +59,66 @@ jobs:
5959
working-directory: xla
6060
run: bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main
6161

62+
# - name: Wait For Connection
63+
# uses: google-ml-infra/actions/ci_connection@main
64+
# with:
65+
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
66+
67+
- name: Create gemm_006f564ad71b327343de5f090e801883.hlo
68+
working-directory: xla
69+
run: |
70+
cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
71+
HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
72+
73+
%gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
74+
%parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
75+
%bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
76+
%parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
77+
%bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
78+
%convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
79+
%dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
80+
ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
81+
}
82+
83+
ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
84+
%multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
85+
%Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
86+
ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
87+
}
88+
EOF
89+
90+
# - name: Wait For Connection
91+
# uses: google-ml-infra/actions/ci_connection@main
92+
# with:
93+
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
94+
95+
- name: Run specific HLO file
96+
working-directory: xla
97+
run: |
98+
./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.log
99+
100+
- name: Run specific HLO file v2
101+
working-directory: xla
102+
run: |
103+
./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo --dump_output_literal_to=results/gemm_00881937d6d49056045c3325a12b108b_v2.log
104+
62105
- name: Wait For Connection
63106
uses: google-ml-infra/actions/ci_connection@main
64107
with:
65108
halt-dispatch-input: ${{ inputs.halt-for-connection }}
109+
- name: Download parse_xla_logs.py
110+
working-directory: xla
111+
run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py
66112

67-
# - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
68-
# working-directory: xla
69-
# run: |
70-
# cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
71-
# HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
72-
73-
# %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
74-
# %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
75-
# %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
76-
# %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
77-
# %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
78-
# %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
79-
# %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
80-
# ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
81-
# }
82-
83-
# ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
84-
# %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
85-
# %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
86-
# ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
87-
# }
88-
# EOF
89-
90-
# # - name: Wait For Connection
91-
# # uses: google-ml-infra/actions/ci_connection@main
92-
# # with:
93-
# # halt-dispatch-input: ${{ inputs.halt-for-connection }}
94-
95-
# - name: Run specific HLO file
96-
# working-directory: xla
97-
# run: |
98-
# ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
99-
100-
# - name: Wait For Connection
101-
# uses: google-ml-infra/actions/ci_connection@main
102-
# with:
103-
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
104-
# - name: Download parse_xla_logs.py
105-
# working-directory: xla
106-
# run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py
107-
108-
# - name: Parse XLA logs
109-
# working-directory: xla
110-
# run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
111-
112-
# - name: Upload Results
113-
# uses: actions/upload-artifact@v4
114-
# with:
115-
# name: gpu-xla-benchmarks
116-
# path: xla/results
113+
- name: Parse XLA logs
114+
working-directory: xla
115+
run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
116+
117+
- name: Upload Results
118+
uses: actions/upload-artifact@v4
119+
with:
120+
name: gpu-xla-benchmarks
121+
path: xla/results
117122
# # jax-build-and-test:
118123
# # runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
119124
# # container:

0 commit comments

Comments
 (0)