Skip to content

Commit 3e2f050

Browse files
Update benchmarks.yml
1 parent 53c5b6e commit 3e2f050

File tree

1 file changed

+31
-24
lines changed

1 file changed

+31
-24
lines changed

.github/workflows/benchmarks.yml

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,6 @@ jobs:
5454
- name: Check TF_CPP_MAX_VLOG_LEVEL
5555
working-directory: xla
5656
run: echo "$TF_CPP_MAX_VLOG_LEVEL"
57-
# - name: Wait For Connection
58-
# uses: google-ml-infra/actions/ci_connection@main
59-
# with:
60-
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
6157

6258
- name: Build hlo_runner_main
6359
working-directory: xla
@@ -68,39 +64,50 @@ jobs:
6864
# with:
6965
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
7066

71-
- name: Create gpu_hlo_backend.hlo
67+
- name: Create gemm_006f564ad71b327343de5f090e801883.hlo
7268
working-directory: xla
7369
run: |
74-
cat << EOF > gpu_hlo_backend.hlo
75-
HloModule module
76-
// CHECK: is_scheduled=true
70+
cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
71+
HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
72+
%gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
73+
%parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
74+
%bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
75+
%parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
76+
%bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
77+
%convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
78+
%dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
79+
ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
80+
}
7781
78-
ENTRY computation {
79-
p = f32[5000,6000]{1,0} parameter(0)
80-
e = f32[5000,6000]{1,0} sqrt(p)
81-
c = f32[6000,5000] transpose(p), dimensions={1,0}
82-
r = f32[300,20,5000] reshape(c)
83-
ROOT out = (f32[5000,6000], f32[300,20,5000]) tuple(e,r)
82+
ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
83+
%multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
84+
%Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
85+
ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
8486
}
8587
EOF
86-
87-
- name: Run specific HLO file
88+
89+
- name: Wait For Connection
90+
uses: google-ml-infra/actions/ci_connection@main
91+
with:
92+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
93+
94+
- name: Run an HLO file
8895
working-directory: xla
8996
run: |
90-
nvidia-smi --query-gpu=utilization.gpu --format=csv -l 1 > results/gpu_utilization_v2.log & ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gpu_hlo_backend.hlo &> results/gpu_hlo_backend.log
91-
92-
# - name: Wait For Connection
93-
# uses: google-ml-infra/actions/ci_connection@main
94-
# with:
95-
# halt-dispatch-input: ${{ inputs.halt-for-connection }}
96-
97+
./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --log_output=True --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
98+
99+
- name: Wait For Connection
100+
uses: google-ml-infra/actions/ci_connection@main
101+
with:
102+
halt-dispatch-input: ${{ inputs.halt-for-connection }}
103+
97104
- name: Download parse_xla_logs.py
98105
working-directory: xla
99106
run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py
100107

101108
- name: Parse XLA logs
102109
working-directory: xla
103-
run: python parse_xla_logs.py results/gpu_hlo_backend.log
110+
run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
104111

105112
- name: Upload Results
106113
uses: actions/upload-artifact@v4

0 commit comments

Comments
 (0)