@@ -23,16 +23,16 @@ jobs:
       options: --gpus all --privileged # Might need privileged mode, use with caution

     steps:
-      # - name: Checkout XLA
-      # uses: actions/checkout@v3
-      # with:
-      # repository: openxla/xla # Replace with your fork if needed
-      # path: xla
-      - name: Checkout repository
+      - name: Checkout XLA
         uses: actions/checkout@v3
         with:
-          repository: juliagmt-google/xla
+          repository: openxla/xla # Replace with your fork if needed
           path: xla
+      # - name: Checkout repository
+      # uses: actions/checkout@v3
+      # with:
+      # repository: juliagmt-google/xla
+      # path: xla

       - name: Create results directory
         working-directory: xla
@@ -59,61 +59,66 @@ jobs:
         working-directory: xla
         run: bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main

+      # - name: Wait For Connection
+      # uses: google-ml-infra/actions/ci_connection@main
+      # with:
+      # halt-dispatch-input: ${{ inputs.halt-for-connection }}
+
+      - name: Create gemm_00881937d6d49056045c3325a12b108b.hlo
+        working-directory: xla
+        run: |
+          cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
+          HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
+
+          %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+          %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
+          %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
+          %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
+          %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
+          %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+          %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+          ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
+          }
+
+          ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+          %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
+          %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
+          ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+          }
+          EOF
+
+      # - name: Wait For Connection
+      # uses: google-ml-infra/actions/ci_connection@main
+      # with:
+      # halt-dispatch-input: ${{ inputs.halt-for-connection }}
+
+      - name: Run specific HLO file
+        working-directory: xla
+        run: |
+          ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.log
+
+      - name: Run specific HLO file v2
+        working-directory: xla
+        run: |
+          ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo --dump_output_literal_to=results/gemm_00881937d6d49056045c3325a12b108b_v2.log
+
       - name: Wait For Connection
         uses: google-ml-infra/actions/ci_connection@main
         with:
           halt-dispatch-input: ${{ inputs.halt-for-connection }}
+      - name: Download parse_xla_logs.py
+        working-directory: xla
+        run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py

-      # - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
-      # working-directory: xla
-      # run: |
-      # cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
-      # HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
-
-      # %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-      # %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
-      # %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
-      # %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
-      # %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
-      # %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-      # %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-      # ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
-      # }
-
-      # ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-      # %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
-      # %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
-      # ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-      # }
-      # EOF
-
-      # # - name: Wait For Connection
-      # # uses: google-ml-infra/actions/ci_connection@main
-      # # with:
-      # # halt-dispatch-input: ${{ inputs.halt-for-connection }}
-
-      # - name: Run specific HLO file
-      # working-directory: xla
-      # run: |
-      # ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
-
-      # - name: Wait For Connection
-      # uses: google-ml-infra/actions/ci_connection@main
-      # with:
-      # halt-dispatch-input: ${{ inputs.halt-for-connection }}
-      # - name: Download parse_xla_logs.py
-      # working-directory: xla
-      # run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py
-
-      # - name: Parse XLA logs
-      # working-directory: xla
-      # run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
-
-      # - name: Upload Results
-      # uses: actions/upload-artifact@v4
-      # with:
-      # name: gpu-xla-benchmarks
-      # path: xla/results
+      - name: Parse XLA logs
+        working-directory: xla
+        run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.log
+
+      - name: Upload Results
+        uses: actions/upload-artifact@v4
+        with:
+          name: gpu-xla-benchmarks
+          path: xla/results
   # # jax-build-and-test:
   # # runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
   # # container:
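
Note on the parse step: the workflow fetches parse_xla_logs.py from the juliagmt-google/xla fork at run time, so the script's contents are not part of this diff. Purely to illustrate what that stage does, below is a minimal, hypothetical sketch of such a log filter; the regex, file handling, and output format are assumptions about the script, not its actual implementation.

#!/usr/bin/env python3
"""Hypothetical sketch of a log-parsing step like parse_xla_logs.py.

The real script is downloaded from the juliagmt-google/xla fork and is not
shown in this diff; this sketch only illustrates the kind of work that
step performs. The regex below is a placeholder, not the real format of
hlo_runner_main output.
"""
import re
import sys

# Placeholder pattern: pick out any line that mentions a duration in s/ms/us.
TIMING_RE = re.compile(r"(\d+(?:\.\d+)?)\s*(ms|us|s)\b")

def parse(path: str) -> None:
    with open(path, "r", errors="replace") as f:
        for line in f:
            if TIMING_RE.search(line):
                # Echo matching lines so they appear in the job output
                # and alongside the uploaded results artifact.
                print(line.rstrip())

if __name__ == "__main__":
    # Usage mirrors the workflow step: python parse_xla_logs.py results/<name>.log
    parse(sys.argv[1])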