@@ -156,27 +156,23 @@ def set_clang_as_compiler(bazel_command: command.CommandBuilder, clang_path: str
156156 else :
157157 logger .debug ("Could not find path to Clang. Continuing without Clang." )
158158
159- def adjust_paths_for_windows (wheel_binary : str , output_dir : str , arch : str ) -> tuple [str , str , str ]:
159+ def adjust_paths_for_windows (output_dir : str , arch : str ) -> tuple [str , str , str ]:
160160 """
161161 Adjusts the paths to be compatible with Windows.
162162 Args:
163- wheel_binary: The path to the wheel binary that was built by Bazel.
164163 output_dir: The output directory for the wheel.
165164 arch: The architecture of the host system.
166165 Returns:
167166 A tuple of the adjusted paths.
168167 """
169168 logger .debug ("Adjusting paths for Windows..." )
170- # On Windows, the wheel binary has a .exe extension. and the path needs
171- # to be adjusted to use backslashes.
172- wheel_binary = wheel_binary .replace ("/" , "\\ " ) + ".exe"
173169 output_dir = output_dir .replace ("/" , "\\ " )
174170
175171 # Change to upper case to match the case in
176172 # "jax/tools/build_utils.py" for Windows.
177173 arch = arch .upper ()
178174
179- return (wheel_binary , output_dir , arch )
175+ return (output_dir , arch )
180176
181177def parse_and_append_bazel_options (bazel_command : command .CommandBuilder , bazel_options : str ):
182178 """
@@ -590,7 +586,7 @@ async def main():
590586 await executor .run (bazel_command .command , args .dry_run )
591587 sys .exit (0 )
592588
593- bazel_command .append ("build " )
589+ bazel_command .append ("run " )
594590
595591 if args .enable_native_arch_features :
596592 logging .debug ("Enabling native target CPU features." )
@@ -672,48 +668,42 @@ async def main():
672668 build_target , wheel_binary = ARTIFACT_BUILD_TARGET_DICT [args .command ]
673669 bazel_command .append (build_target )
674670
675- # Execute the Bazel command.
676- await executor .run (bazel_command .command , args .dry_run )
677-
678- # Construct the wheel build command.
679- logger .info ("Constructing wheel build command..." )
680-
681671 # Read output directory. Default is store the artifacts in the "dist/"
682672 # directory in JAX's GitHub repository root.
683673 output_dir = args .output_dir
684674
685675 # If running on Windows, adjust the paths for compatibility.
686676 if os_name == "windows" :
687- wheel_binary , output_dir , arch = adjust_paths_for_windows (
688- wheel_binary , output_dir , arch
677+ output_dir , arch = adjust_paths_for_windows (
678+ output_dir , arch
689679 )
690680
691681 logger .debug ("Storing artifacts in %s" , output_dir )
692682
693- run_wheel_binary = command . CommandBuilder ( wheel_binary )
683+ bazel_command . append ( "--" )
694684
695685 if args .editable :
696686 logger .debug ("Building an editable build." )
697687 output_dir = os .path .join (output_dir , args .command )
698- run_wheel_binary .append ("--editable" )
688+ bazel_command .append ("--editable" )
699689
700- run_wheel_binary .append (f"--output_path={ output_dir } " )
701- run_wheel_binary .append (f"--cpu={ arch } " )
690+ bazel_command .append (f"--output_path={ output_dir } " )
691+ bazel_command .append (f"--cpu={ arch } " )
702692
703693 if "cuda" in args .command :
704- run_wheel_binary .append ("--enable-cuda=True" )
694+ bazel_command .append ("--enable-cuda=True" )
705695 major_cuda_version = args .cuda_version .split ("." )[0 ]
706- run_wheel_binary .append (f"--platform_version={ major_cuda_version } " )
696+ bazel_command .append (f"--platform_version={ major_cuda_version } " )
707697
708698 if "rocm" in args .command :
709- run_wheel_binary .append ("--enable-rocm=True" )
710- run_wheel_binary .append (f"--platform_version={ args .rocm_version } " )
699+ bazel_command .append ("--enable-rocm=True" )
700+ bazel_command .append (f"--platform_version={ args .rocm_version } " )
711701
712702 jaxlib_git_hash = get_jaxlib_git_hash ()
713- run_wheel_binary .append (f"--jaxlib_git_hash={ jaxlib_git_hash } " )
703+ bazel_command .append (f"--jaxlib_git_hash={ jaxlib_git_hash } " )
714704
715705 # Execute the wheel build command.
716- await executor .run (run_wheel_binary .command , args .dry_run )
706+ await executor .run (bazel_command .command , args .dry_run )
717707
718708if __name__ == "__main__" :
719709 asyncio .run (main ())
0 commit comments