@@ -82,27 +82,42 @@ jobs:
8282 working-directory : xla
8383 run : bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main
8484
85- - name : Create b284431534_transpose_convert_f32_s8.hlo
85+ - name : Create gemm_006f564ad71b327343de5f090e801883.hlo
8686 working-directory : xla
8787 run : |
88- cat << EOF > b284431534_transpose_convert_f32_s8.hlo
89- HloModule test, entry_computation_layout={(f32[1,4,32,192,384]{4,3,2,1,0})->s8[1,4,192,384,32]{4,3,2,1,0}}
90-
91- fusion {
92- param_0 = f32[1,4,32,192,384] parameter(0)
93- transpose = f32[1,4,192,384,32] transpose(param_0), dimensions={0,1,3,4,2}
94- ROOT convert = s8[1,4,192,384,32] convert(transpose)
88+ cat << EOF > gemm_006f564ad71b327343de5f090e801883.hlo
90+ HloModule gemm_fusion_dot.166, entry_computation_layout={(bf16[8,12,2048,2048]{3,2,1,0}, bf16[16384,128]{1,0}, bf16[128]{0})->bf16[8,12,2048,128]{3,2,1,0}}
91+
92+ %gemm_fusion_dot.166_computation.clone (parameter_0.167: bf16[8,12,2048,2048], parameter_1.167: bf16[16384,128], parameter_2.18: bf16[128]) -> bf16[8,12,2048,128] {
93+ %parameter_0.167 = bf16[8,12,2048,2048]{3,2,1,0} parameter(0)
94+ %bitcast.22615 = bf16[8,24576,2048]{2,1,0} bitcast(bf16[8,12,2048,2048]{3,2,1,0} %parameter_0.167)
95+ %parameter_1.167 = bf16[16384,128]{1,0} parameter(1)
96+ %parameter_2.18 = bf16[128]{0} parameter(2)
97+ %broadcast.9073 = bf16[16384,128]{1,0} broadcast(bf16[128]{0} %parameter_2.18), dimensions={1}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/value/mul" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=228}
98+ %multiply.7656 = bf16[16384,128]{1,0} multiply(bf16[16384,128]{1,0} %parameter_1.167, bf16[16384,128]{1,0} %broadcast.9073), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/value/mul" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=228}
99+ %bitcast.22616 = bf16[8,2048,128]{2,1,0} bitcast(bf16[16384,128]{1,0} %multiply.7656)
100+ %dot.1454 = bf16[8,24576,128]{2,1,0} dot(bf16[8,24576,2048]{2,1,0} %bitcast.22615, bf16[8,2048,128]{2,1,0} %bitcast.22616), lhs_batch_dims={0}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/pv_einsum/BNTS,BSH->BNTH/dot_general[dimension_numbers=(((3,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/base_ops.py" source_line=28}
101+ ROOT %bitcast.22617 = bf16[8,12,2048,128]{3,2,1,0} bitcast(bf16[8,24576,128]{2,1,0} %dot.1454), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/transpose[permutation=(0, 2, 1, 3)]" source_file="third_party/py/praxis/layers/multi_query_attention.py" source_line=454}
95102 }
96-
97- ENTRY main {
98- param_0 = f32[1,4,32,192,384] parameter(0)
99- ROOT fusion = s8[1,4,192,384,32] fusion(param_0), kind=kInput, calls=fusion
103+
104+ ENTRY %entry_computation (convert.8139: bf16[8,12,2048,2048], gemm_fusion_dot.163: bf16[16384,128], Arg_23.24: bf16[128]) -> bf16[8,12,2048,128] {
105+ %convert.8139 = bf16[8,12,2048,2048]{3,2,1,0} parameter(0)
106+ %gemm_fusion_dot.163 = bf16[16384,128]{1,0} parameter(1)
107+ %Arg_23.24 = bf16[128]{0} parameter(2)
108+ ROOT %micro_kernel = bf16[8,12,2048,128]{3,2,1,0} fusion(bf16[8,12,2048,2048]{3,2,1,0} %convert.8139, bf16[16384,128]{1,0} %gemm_fusion_dot.163, bf16[128]{0} %Arg_23.24), kind=kCustom, calls=%gemm_fusion_dot.166_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/self_attention._dot_atten/pv_einsum/BNTS,BSH->BNTH/dot_general[dimension_numbers=(((3,), (1,)), ((0,), (0,))) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/base_ops.py" source_line=28}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
100110 }
101111 EOF
102112
113+ - name : Wait For Connection
114+ uses : google-ml-infra/actions/ci_connection@main
115+ with :
116+ halt-dispatch-input : ${{ inputs.halt-for-connection }}
117+
103118 - name : Run specific HLO file
104119 working-directory : xla
105- run : ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning b284431534_transpose_convert_f32_s8.hlo &> results/b284431534_transpose_convert_f32_s8.hlo.log
120+ run : ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_006f564ad71b327343de5f090e801883.hlo &> results/gemm_006f564ad71b327343de5f090e801883.hlo.log
106121
107122 # - name: Run HLO Module Benchmarks with GPU in xla/tests/fuzz
108123 # working-directory: xla
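For local reproduction outside the workflow, the same two commands can be run by hand from an xla checkout. This is a minimal sketch based only on the build and run steps in this diff; it assumes a CUDA-capable host and that the results/ directory already exists:

    # Build the runner binary (same bazel target as the build step above).
    bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main

    # Replay the HLO module on GPU and capture the log, mirroring the "Run specific HLO file" step.
    ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning \
        gemm_006f564ad71b327343de5f090e801883.hlo &> results/gemm_006f564ad71b327343de5f090e801883.hlo.log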