
Commit b2206e1

Update benchmarks.yml
1 parent dee2556 commit b2206e1


.github/workflows/benchmarks.yml

Lines changed: 16 additions & 21 deletions
@@ -85,28 +85,23 @@ jobs:
       - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
         working-directory: xla
         run: |
-          cat << EOF > gemm_006f564ad71b327343de5f090e801883.hlo
-
-          HloModule gemm_fusion_dot.166, entry_computation_layout={(bf16[8,12,2048,2048]{3,2,1,0}, bf16[16384,128]{1,0}, bf16[128]{0})->bf16[8,12,2048,128]{3,2,1,0}}
-
-          %gemm_fusion_dot.166_computation.clone (parameter_0.167: bf16[8,12,2048,2048], parameter_1.167: bf16[16384,128], parameter_2.18: bf16[128]) -> bf16[8,12,2048,128] {
-          %parameter_0.167 = bf16[8,12,2048,2048]{3,2,1,0} parameter(0)
-          %bitcast.22615 = bf16[8,24576,2048]{2,1,0} bitcast(bf16[8,12,2048,2048]{3,2,1,0} %parameter_0.167)
-          %parameter_1.167 = bf16[16384,128]{1,0} parameter(1)
-          %parameter_2.18 = bf16[128]{0} parameter(2)
-          %broadcast.9073 = bf16[16384,128]{1,0} broadcast(bf16[128]{0} %parameter_2.18), dimensions={1}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/value/mul" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=228}
-          %multiply.7656 = bf16[16384,128]{1,0} multiply(bf16[16384,128]{1,0} %parameter_1.167, bf16[16384,128]{1,0} %broadcast.9073), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/value/mul" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=228}
-          %bitcast.22616 = bf16[8,2048,128]{2,1,0} bitcast(bf16[16384,128]{1,0} %multiply.7656)
-          %dot.1454 = bf16[8,24576,128]{2,1,0} dot(bf16[8,24576,2048]{2,1,0} %bitcast.22615, bf16[8,2048,128]{2,1,0} %bitcast.22616), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/pv_einsum/BNTS,BSH->BNTH/dot_general[dimension_numbers=(((3,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/base_ops.py" source_line=28}
-          ROOT %bitcast.22617 = bf16[8,12,2048,128]{3,2,1,0} bitcast(bf16[8,24576,128]{2,1,0} %dot.1454), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/transpose[permutation=(0, 2, 1, 3)]" source_file="third_party/py/praxis/layers/multi_query_attention.py" source_line=454}
+          cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
+          HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
+
+          %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+          %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
+          %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
+          %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
+          %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
+          %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+          %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+          ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
           }

-          ENTRY %entry_computation (convert.8139: bf16[8,12,2048,2048], gemm_fusion_dot.163: bf16[16384,128], Arg_23.24: bf16[128]) -> bf16[8,12,2048,128] {
-          %convert.8139 = bf16[8,12,2048,2048]{3,2,1,0} parameter(0)
-          %gemm_fusion_dot.163 = bf16[16384,128]{1,0} parameter(1)
-          %Arg_23.24 = bf16[128]{0} parameter(2)
-          ROOT %micro_kernel = bf16[8,12,2048,128]{3,2,1,0} fusion(bf16[8,12,2048,2048]{3,2,1,0} %convert.8139, bf16[16384,128]{1,0} %gemm_fusion_dot.163, bf16[128]{0} %Arg_23.24), kind=kCustom, calls=%gemm_fusion_dot.166_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/pv_einsum/BNTS,BSH->BNTH/dot_general[dimension_numbers=(((3,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/base_ops.py" source_line=28}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-          } ROOT fusion = s8[1,4,192,384,32] fusion(param_0), kind=kInput, calls=fusion
+          ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+          %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
+          %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
+          ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
           }
           EOF

@@ -117,7 +112,7 @@ jobs:

       - name: Run specific HLO file
         working-directory: xla
-        run: ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_006f564ad71b327343de5f090e801883.hlo &> results/gemm_006f564ad71b327343de5f090e801883.hlo.log
+        run: ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log

       # - name: Run HLO Module Benchmarks with GPU in xla/tests/fuzz
       #   working-directory: xla

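For context, the HLO module added above is the quantized query-projection GEMM that its metadata points at: a bf16[1,8192,3072] activation is contracted with an s8[3072,6,512] weight that is first converted to bfloat16 (the ABD,DNH->ABNH einsum from praxis' quantization operations). Below is a minimal JAX sketch of that computation, using only the shapes and dtypes from the HLO; the function name and the zero-filled inputs are illustrative and are not part of the commit or the workflow.

import jax
import jax.numpy as jnp

def quantized_query_projection(x, w_s8):
    # x:    bf16[1, 8192, 3072]  activations (einsum dims A, B, D)
    # w_s8: s8[3072, 6, 512]     int8 weights (einsum dims D, N, H)
    w = w_s8.astype(jnp.bfloat16)             # corresponds to %convert.18528
    return jnp.einsum("ABD,DNH->ABNH", x, w)  # corresponds to %dot.4949 and the surrounding bitcasts

x = jnp.zeros((1, 8192, 3072), dtype=jnp.bfloat16)
w_s8 = jnp.zeros((3072, 6, 512), dtype=jnp.int8)
out = jax.jit(quantized_query_projection)(x, w_s8)
print(out.shape, out.dtype)                   # (1, 8192, 6, 512) bfloat16

The benchmark itself never goes through JAX: the second hunk simply points hlo_runner_main at the dumped gemm_00881937d6d49056045c3325a12b108b.hlo file and writes its log under results/.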