Skip to content

Commit bfef9f5

Browse files
committed
fix invalid merge
1 parent e7e5e06 commit bfef9f5

File tree

4 files changed

+11
-64
lines changed

4 files changed

+11
-64
lines changed

ci/envs/default.env

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,4 +66,4 @@ export JAXCI_TPU_CORES=${JAXCI_TPU_CORES:-}
6666
# JAXCI_PYTHON points to the Python interpreter to use for installing JAX wheels
6767
# on the system. By default, it is set to match the version of the hermetic
6868
# Python used by Bazel for building the wheels.
69-
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}
69+
export JAXCI_PYTHON=${JAXCI_PYTHON:-python${JAXCI_HERMETIC_PYTHON_VERSION}}

ci/run_pytest_gpu.sh

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
# ==============================================================================
16-
<<<<<<< HEAD
17-
# Runs Pyest CPU tests. Requires all jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
18-
=======
1916
# Runs Pyest CPU tests. Requires the jaxlib, jax-cuda-plugin, and jax-cuda-pjrt
20-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
2117
# wheels to be present inside $JAXCI_OUTPUT_DIR (../dist)
2218
#
2319
# -e: abort script if one command fails
@@ -27,39 +23,17 @@
2723
# -o allexport: export all functions and variables to be available to subscripts
2824
set -exu -o history -o allexport
2925

30-
<<<<<<< HEAD
31-
# Inherit default JAXCI environment variables.
32-
source ci/envs/default.env
33-
34-
# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels on the system.
35-
=======
3626
# Source default JAXCI environment variables.
3727
source ci/envs/default.env
3828

3929
# Install jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels inside the
4030
# $JAXCI_OUTPUT_DIR directory on the system.
41-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
4231
echo "Installing wheels locally..."
4332
source ./ci/utilities/install_wheels_locally.sh
4433

4534
# Set up the build environment.
4635
source "ci/utilities/setup_build_environment.sh"
4736

48-
<<<<<<< HEAD
49-
export PY_COLORS=1
50-
export JAX_SKIP_SLOW_TESTS=true
51-
52-
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
53-
54-
nvidia-smi
55-
export NCCL_DEBUG=WARN
56-
export TF_CPP_MIN_LOG_LEVEL=0
57-
58-
echo "Running GPU tests..."
59-
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
60-
export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
61-
"$JAXCI_PYTHON" -m pytest -n 8 --tb=short --maxfail=20 \
62-
=======
6337
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
6438

6539
nvidia-smi
@@ -80,7 +54,6 @@ export XLA_FLAGS=--xla_gpu_force_compilation_parallelism=1
8054

8155
echo "Running GPU tests..."
8256
"$JAXCI_PYTHON" -m pytest -n $num_processes --tb=short --maxfail=20 \
83-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
8457
tests examples \
8558
--deselect=tests/multi_device_test.py::MultiDeviceTest::test_computation_follows_data \
8659
--deselect=tests/xmap_test.py::XMapTest::testCollectivePermute2D \

ci/run_pytest_tpu.sh

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,29 +33,27 @@ source ./ci/utilities/install_wheels_locally.sh
3333
# Set up the build environment.
3434
source "ci/utilities/setup_build_environment.sh"
3535

36-
export PY_COLORS=1
37-
export JAX_SKIP_SLOW_TESTS=true
38-
3936
"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"
40-
4137
"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
4238
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
4339
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
44-
strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
40+
strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
4541
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'
4642

47-
echo "Running TPU tests..."
43+
# Set up common test environment variables
44+
export PY_COLORS=1
45+
export JAX_SKIP_SLOW_TESTS=true
4846
export JAX_PLATFORMS=tpu,cpu
49-
# Run single-accelerator tests in parallel
50-
export JAX_ENABLE_TPU_XDIST=true
47+
# End of common test environment variable setup
5148

52-
"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
49+
echo "Running TPU tests..."
50+
# Run single-accelerator tests in parallel
51+
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
5352
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
5453
--maxfail=20 -m "not multiaccelerator" tests examples
5554

5655
# Run Pallas printing tests, which need to run with I/O capturing disabled.
57-
export TPU_STDERR_LOG_LEVEL=0
58-
"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
56+
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
5957

6058
# Run multi-accelerator across all chips
6159
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests

ci/utilities/convert_msys_paths_to_win_paths.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@
1616
Converts MSYS Linux-like paths stored in env variables to Windows paths.
1717
1818
This is necessary on Windows, because some applications do not understand/handle
19-
<<<<<<< HEAD
20-
Linux-like paths MSYS uses, for example, Docker.
21-
=======
2219
Linux-like paths MSYS uses, for example, Bazel.
23-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
2420
"""
2521
import argparse
2622
import os
@@ -47,8 +43,6 @@ def msys_to_windows_path(msys_path):
4743
print(f"Error converting path: {e}")
4844
return None
4945

50-
<<<<<<< HEAD
51-
=======
5246
def should_convert(var: str,
5347
convert: list[str] | None):
5448
"""Check the variable name against convert list"""
@@ -57,24 +51,15 @@ def should_convert(var: str,
5751
else:
5852
return False
5953

60-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
6154
def main(parsed_args: argparse.Namespace):
6255
converted_paths = {}
6356

6457
for var, value in os.environ.items():
65-
<<<<<<< HEAD
66-
if (parsed_args.blacklist and var in parsed_args.blacklist) or not value:
67-
continue
68-
if "_DIR" in var or (args.whitelist and var in parsed_args.whitelist):
69-
converted_path = msys_to_windows_path(value)
70-
converted_paths[var] = converted_path
71-
=======
7258
if not value or not should_convert(var,
7359
parsed_args.convert):
7460
continue
7561
converted_path = msys_to_windows_path(value)
7662
converted_paths[var] = converted_path
77-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
7863

7964
var_str = '\n'.join(f'export {k}="{v}"'
8065
for k, v in converted_paths.items())
@@ -86,19 +71,10 @@ def main(parsed_args: argparse.Namespace):
8671
if __name__ == '__main__':
8772
parser = argparse.ArgumentParser(description=(
8873
'Convert MSYS paths in environment variables to Windows paths.'))
89-
<<<<<<< HEAD
90-
parser.add_argument('--blacklist',
91-
nargs='*',
92-
help='List of variables to ignore')
93-
parser.add_argument('--whitelist',
94-
nargs='*',
95-
help='List of variables to include')
96-
=======
9774
parser.add_argument('--convert',
9875
nargs='+',
9976
required=True,
10077
help='Space separated list of environment variables to convert. E.g: --convert env_var1 env_var2')
101-
>>>>>>> 5ade371c88a1f879556ec29867b173da49ae57f0
10278
args = parser.parse_args()
10379

104-
main(args)
80+
main(args)

0 commit comments

Comments
 (0)