@@ -23,10 +23,15 @@ jobs:
       options: --gpus all --privileged # Might need privileged mode, use with caution

     steps:
-      - name: Checkout XLA
+# - name: Checkout XLA
+#   uses: actions/checkout@v3
+#   with:
+#     repository: openxla/xla # Replace with your fork if needed
+#     path: xla
+      - name: Checkout repository
         uses: actions/checkout@v3
         with:
-          repository: openxla/xla # Replace with your fork if needed
+          repository: juliagmt-google/xla
           path: xla

       - name: Create results directory
@@ -53,103 +58,108 @@ jobs:
       - name: Build hlo_runner_main
         working-directory: xla
         run: bazel build -c opt --config=cuda --dynamic_mode=off //xla/tools/multihost_hlo_runner:hlo_runner_main
-
-      - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
-        working-directory: xla
-        run: |
-          cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
-          HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
-
-          %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-          %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
-          %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
-          %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
-          %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
-          %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-          %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
-          ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
-          }
-
-          ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
-          %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
-          %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
-          ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
-          }
-          EOF
-
-# - name: Wait For Connection
-# uses: google-ml-infra/actions/ci_connection@main
-# with:
-# halt-dispatch-input: ${{ inputs.halt-for-connection }}
-
-      - name: Run specific HLO file
-        working-directory: xla
-        run: |
-          ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
-
+
       - name: Wait For Connection
         uses: google-ml-infra/actions/ci_connection@main
         with:
           halt-dispatch-input: ${{ inputs.halt-for-connection }}
-      - name: Download parse_xla_logs.py
-        working-directory: xla
-        run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py

-      - name: Parse XLA logs
-        working-directory: xla
-        run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
+# - name: Create gemm_006f564ad71b327343de5f090e801883.hlo
+# working-directory: xla
+# run: |
+# cat << EOF > gemm_00881937d6d49056045c3325a12b108b.hlo
+# HloModule gemm_fusion_dot.542, entry_computation_layout={(bf16[1,8192,3072]{2,1,0}, s8[3072,6,512]{2,1,0})->bf16[1,8192,6,512]{3,2,1,0}}
+
+# %gemm_fusion_dot.542_computation.clone (parameter_0.543: bf16[1,8192,3072], parameter_1.543: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+# %parameter_0.543 = bf16[1,8192,3072]{2,1,0} parameter(0)
+# %bitcast.69925 = bf16[8192,3072]{1,0} bitcast(bf16[1,8192,3072]{2,1,0} %parameter_0.543)
+# %parameter_1.543 = s8[3072,6,512]{2,1,0} parameter(1)
+# %bitcast.69926 = s8[3072,3072]{1,0} bitcast(s8[3072,6,512]{2,1,0} %parameter_1.543)
+# %convert.18528 = bf16[3072,3072]{1,0} convert(s8[3072,3072]{1,0} %bitcast.69926), metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/convert_element_type[new_dtype=bfloat16 weak_type=False]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+# %dot.4949 = bf16[8192,3072]{1,0} dot(bf16[8192,3072]{1,0} %bitcast.69925, bf16[3072,3072]{1,0} %convert.18528), lhs_contracting_dims={1}, rhs_contracting_dims={0}, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}
+# ROOT %bitcast.69927 = bf16[1,8192,6,512]{3,2,1,0} bitcast(bf16[8192,3072]{1,0} %dot.4949)
+# }
+
+# ENTRY %entry_computation (multiply.28104: bf16[1,8192,3072], Arg_51.52: s8[3072,6,512]) -> bf16[1,8192,6,512] {
+# %multiply.28104 = bf16[1,8192,3072]{2,1,0} parameter(0)
+# %Arg_51.52 = s8[3072,6,512]{2,1,0} parameter(1)
+# ROOT %micro_kernel = bf16[1,8192,6,512]{3,2,1,0} fusion(bf16[1,8192,3072]{2,1,0} %multiply.28104, s8[3072,6,512]{2,1,0} %Arg_51.52), kind=kCustom, calls=%gemm_fusion_dot.542_computation.clone, metadata={op_name="pjit(_wrapped_fn)/jit(main)/tarzan_lm.apply/tarzan_lm.decode_with_params/lm/transformer/x_layers_0/self_attention/query/query.quantized_einsum/ABD,DNH->ABNH/dot_general[dimension_numbers=(((2,), (0,)), ((), ())) precision=None preferred_element_type=None]" source_file="third_party/py/praxis/layers/quantization/operations.py" source_line=220}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm"},"force_earliest_schedule":false}
+# }
+# EOF
+
+# # - name: Wait For Connection
+# # uses: google-ml-infra/actions/ci_connection@main
+# # with:
+# # halt-dispatch-input: ${{ inputs.halt-for-connection }}
+
+# - name: Run specific HLO file
+# working-directory: xla
+# run: |
+# ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning gemm_00881937d6d49056045c3325a12b108b.hlo &> results/gemm_00881937d6d49056045c3325a12b108b.hlo.log

-      - name: Upload Results
-        uses: actions/upload-artifact@v4
-        with:
-          name: gpu-xla-benchmarks
-          path: xla/results
-# jax-build-and-test:
-# runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
-# container:
-# image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
-
-# env:
-# JAXCI_HERMETIC_PYTHON_VERSION: 3.11
-
-# steps:
-# - name: Checkout JAX Fork
-# uses: actions/checkout@v3
+# - name: Wait For Connection
+# uses: google-ml-infra/actions/ci_connection@main
 # with:
-# repository: 'google-ml-infra/jax-fork'
-# path: jax-fork
+# halt-dispatch-input: ${{ inputs.halt-for-connection }}
+# - name: Download parse_xla_logs.py
+# working-directory: xla
+# run: wget https://raw.githubusercontent.com/juliagmt-google/xla/main/.github/workflows/parse_xla_logs.py

-# - name: Install JAX Dependencies
-# working-directory: jax-fork
-# run: |
-# python -m pip install --upgrade pip
-# pip install pytest
-# pip install absl-py
-# pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
-# pip install google-benchmark
-# - name: Run JAX Multiprocess GPU Test
-# working-directory: jax-fork
-# continue-on-error: true
-# run: python -m pytest tests/multiprocess_gpu_test.py
+# - name: Parse XLA logs
+# working-directory: xla
+# run: python parse_xla_logs.py results/gemm_00881937d6d49056045c3325a12b108b.hlo.log
+
+# - name: Upload Results
+# uses: actions/upload-artifact@v4
+# with:
+# name: gpu-xla-benchmarks
+# path: xla/results
+# # jax-build-and-test:
+# # runs-on: linux-x86-g2-48-l4-4gpu # Use a GPU-enabled runner
+# # container:
+# # image: "gcr.io/tensorflow-testing/nosla-cuda12.3-cudnn9.1-ubuntu20.04-manylinux2014-multipython:latest"
+
+# # env:
+# # JAXCI_HERMETIC_PYTHON_VERSION: 3.11
+
+# # steps:
+# # - name: Checkout JAX Fork
+# # uses: actions/checkout@v3
+# # with:
+# # repository: 'google-ml-infra/jax-fork'
+# # path: jax-fork
+
+# # - name: Install JAX Dependencies
+# # working-directory: jax-fork
+# # run: |
+# # python -m pip install --upgrade pip
+# # pip install pytest
+# # pip install absl-py
+# # pip install "jax[cuda12_pip]" # Adjust CUDA version if needed
+# # pip install google-benchmark
+# # - name: Run JAX Multiprocess GPU Test
+# # working-directory: jax-fork
+# # continue-on-error: true
+# # run: python -m pytest tests/multiprocess_gpu_test.py


-# - name: Run HLO Module Benchmarks withg GPU in xla/tests/fuzz
-# working-directory: xla
-# continue-on-error: true
-# run: |
-# for file in xla/tests/fuzz/*.hlo; do
-# filename=$(basename "$file")
-# # Skip expected failed hlo files.
-# if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then
-# echo "Skipping benchmark on $file"
-# continue
-# fi
-# echo "Running benchmark on $file" &> results/"$filename".log
-# ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning "$file" &> results/"$filename".log
-# done
-
-# - name: Upload Results
-# uses: actions/upload-artifact@v4
-# with:
-# name: gpu-xla-benchmarks
-# path: xla/results
+# # - name: Run HLO Module Benchmarks withg GPU in xla/tests/fuzz
+# # working-directory: xla
+# # continue-on-error: true
+# # run: |
+# # for file in xla/tests/fuzz/*.hlo; do
+# # filename=$(basename "$file")
+# # # Skip expected failed hlo files.
+# # if [[ "$filename" == "rand_000060.hlo" || "$filename" == "rand_000067.hlo" || "$filename" == "rand_000072.hlo" ]]; then
+# # echo "Skipping benchmark on $file"
+# # continue
+# # fi
+# # echo "Running benchmark on $file" &> results/"$filename".log
+# # ./bazel-bin/xla/tools/multihost_hlo_runner/hlo_runner_main --device_type=gpu --use_spmd_partitioning "$file" &> results/"$filename".log
+# # done
+
+# # - name: Upload Results
+# # uses: actions/upload-artifact@v4
+# # with:
+# # name: gpu-xla-benchmarks
+# # path: xla/results