Skip to content

Commit d9c0466

Browse files
authored
Allow a custom pyproject.toml template (#75)
* Allow a custom pyproject.toml template. This update enables users to provide a custom pyproject.toml file as a template for the files generated by seed-env. Here is how the new functionality works: * The tool can now accept a user-provided pyproject.toml template via the `--template-pyproject-toml` flag. * By default, if a pyproject.toml file is present at the project's root when the seed-env CLI is called, the tool will automatically use it as the template. * The template file used for input cannot be the same as the pyproject.toml file generated in the --output-dir. * Disallow a pre-existing pyproject.toml in the --output-dir. * Fix ruff format.
1 parent fbf7e3e commit d9c0466

File tree

7 files changed

+106
-32
lines changed

7 files changed

+106
-32
lines changed

python_seed_env/src/seed_env/cli.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ def main():
8383
)
8484

8585
# --- Common Arguments ---
86+
parser.add_argument(
87+
"--template-pyproject-toml",
88+
type=str,
89+
default=None,
90+
help="Path to a custom pyproject.toml file to use as a template.",
91+
)
8692
parser.add_argument(
8793
"--seed-config",
8894
type=str,
@@ -207,6 +213,7 @@ def main():
207213
hardware=args.hardware,
208214
build_pypi_package=args.build_pypi_package,
209215
output_dir=args.output_dir,
216+
template_pyproject_toml=args.template_pyproject_toml,
210217
)
211218
# Core function
212219
host_env_seeder.seed_environment()

python_seed_env/src/seed_env/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,12 +52,9 @@
5252
"nvidia-nvvm",
5353
"jax-cuda12-plugin",
5454
"jax-cuda13-plugin",
55-
# "jax-cuda12-plugin[with-cuda]",
5655
"jax-cuda12-pjrt",
5756
"jax-cuda13-pjrt",
5857
"transformer-engine",
59-
# "transformer-engine[jax]",
60-
# "transformer-engine[pytorch]",
6158
]
6259

6360
TPU_SPECIFIC_DEPS = [

python_seed_env/src/seed_env/core.py

Lines changed: 53 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,18 @@
1616

1717
import os
1818
import logging
19+
import shutil
1920
import yaml
2021
from importlib.resources import files
2122
from seed_env.seeder import Seeder
2223
from seed_env.utils import generate_minimal_pyproject_toml
2324
from seed_env.git_utils import download_remote_git_file
2425
from seed_env.uv_utils import (
26+
set_exact_python_requirement_in_project_toml,
2527
build_seed_env,
2628
build_pypi_package,
2729
merge_project_toml_files,
30+
replace_dependencies_in_project_toml,
2831
)
2932

3033
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -59,13 +62,15 @@ def __init__(
5962
hardware: str,
6063
build_pypi_package: bool,
6164
output_dir: str,
65+
template_pyproject_toml: str = None,
6266
):
6367
self.host_name = host_name
6468
self.host_source_type = host_source_type
6569
self.host_github_org_repo = host_github_org_repo
6670
self.host_requirements_file_path = host_requirements_file_path
6771
self.host_commit = host_commit
6872
self.seed_config_input = seed_config
73+
self.template_pyproject_toml = template_pyproject_toml
6974
self.loaded_seed_config = None
7075
self.seed_tag_or_commit = seed_tag_or_commit
7176
self.python_versions = python_version.split(",")
@@ -143,6 +148,31 @@ def seed_environment(self):
143148
os.makedirs(self.output_dir, exist_ok=True)
144149
self.output_dir = os.path.abspath(self.output_dir)
145150

151+
# Determine the template for pyproject.toml. The explicit CLI argument takes precedence.
152+
template_path = self.template_pyproject_toml
153+
if not template_path and os.path.isfile("./pyproject.toml"):
154+
template_path = os.path.abspath("./pyproject.toml")
155+
logging.info(
156+
f"Found pyproject.toml in the current directory. Using it as a template: {template_path}"
157+
)
158+
159+
# Pre-flight check: Ensure the output directory root is clean of a pyproject.toml, as we will generate one.
160+
final_pyproject_path = os.path.join(self.output_dir, "pyproject.toml")
161+
if os.path.isfile(final_pyproject_path):
162+
# Check for the specific edge case where the output directory is the project root
163+
# and the existing pyproject.toml is the one we are using as a template.
164+
if template_path and os.path.samefile(template_path, final_pyproject_path):
165+
raise FileExistsError(
166+
f"The output directory ('{self.output_dir}') contains a 'pyproject.toml', which was found to be used as a template. "
167+
"Running this would overwrite the original template file. Please use a different --output-dir; or move the"
168+
"existing pyproject.toml to a different location and use the --template-pyproject-toml flag to specify its new location."
169+
)
170+
# General case: the output directory contains a pre-existing pyproject.toml.
171+
raise FileExistsError(
172+
f"A pyproject.toml file already exists in the output directory: {self.output_dir}. "
173+
"Please provide a clean directory or remove the file to avoid accidentaly overwriting it."
174+
)
175+
146176
# Create a directory for storing the downloaded requirements file
147177
self.download_dir = "downloaded_base_and_seed_requirements"
148178
os.makedirs(self.download_dir, exist_ok=True)
@@ -179,30 +209,36 @@ def seed_environment(self):
179209
f"Using {self.seeder.pypi_project_name} at tag/commit {self.seed_tag_or_commit} on {self.seeder.github_org_repo} as seed"
180210
)
181211

182-
# Remove pyproject.toml if it exists, as we will generate a new one with merge_project_toml_files
183-
pyproject_file = os.path.join(self.output_dir, "pyproject.toml")
184-
if os.path.isfile(pyproject_file):
185-
os.remove(pyproject_file)
186-
logging.info(f"Removed existing pyproject.toml file: {pyproject_file}")
187-
188212
versioned_project_toml_files = []
189213
for python_version in self.python_versions:
190214
# Generate a subdir for each python version
191215
versioned_output_dir = (
192216
self.output_dir + "/python" + python_version.replace(".", "_")
193217
)
194-
versioned_project_toml_files.append(versioned_output_dir + "/pyproject.toml")
195218
os.makedirs(versioned_output_dir, exist_ok=True)
219+
versioned_pyproject_path = os.path.join(versioned_output_dir, "pyproject.toml")
220+
versioned_project_toml_files.append(versioned_pyproject_path)
196221

197222
# 3. Download the seed lock file for the specified Python version
198223
SEED_LOCK_FILE = os.path.abspath(
199224
self.seeder.download_seed_lock_requirement(python_version)
200225
)
201226

202-
# 4. Generate a minimal pyproject.toml file for the specified Python version to the output directory
203-
generate_minimal_pyproject_toml(
204-
self.host_name, python_version, versioned_output_dir
205-
)
227+
# 4. Generate a pyproject.toml file for the specified Python version.
228+
if template_path:
229+
logging.info(f"Using template {template_path} for Python {python_version}")
230+
shutil.copy(template_path, versioned_pyproject_path)
231+
# Clear any existing dependencies from the template to start fresh.
232+
replace_dependencies_in_project_toml([], versioned_pyproject_path)
233+
# Update the python version in the copied template to be specific for this build pass.
234+
set_exact_python_requirement_in_project_toml(
235+
python_version, versioned_pyproject_path
236+
)
237+
else:
238+
logging.info(f"Generating minimal pyproject.toml for Python {python_version}")
239+
generate_minimal_pyproject_toml(
240+
self.host_name, python_version, versioned_output_dir
241+
)
206242

207243
# Construct the host lock file name
208244
HOST_LOCK_FILE_NAME = f"{self.host_name.replace('-', '_')}_requirements_lock_{python_version.replace('.', '_')}.txt"
@@ -220,6 +256,12 @@ def seed_environment(self):
220256
merge_project_toml_files(versioned_project_toml_files, self.output_dir)
221257

222258
# 6. Build pypi package
259+
# TODO(kanglant): Assume where the seed-env cli is called is the project root
260+
# and move the generated pyproject.toml to the project root? If there is an old
261+
# pyproject.toml here, this behavior will overwrite it. I think this is risky.
262+
# Another option is to copy the source files, specified in the new pyproject.toml,
263+
# from the project root to the output_dir folder. Then perform a fully isolated
264+
# build at the output_dir.
223265
if self.build_pypi_package:
224266
# Use the new pyproject.toml file at the output dir to build the package.
225267
build_pypi_package(self.output_dir)

python_seed_env/src/seed_env/utils.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -98,21 +98,15 @@ def generate_minimal_pyproject_toml(
9898
9999
[project]
100100
name = "{project_name}"
101-
description = "{project_name} is a simple, performant and scalable Jax LLM!"
102101
version = "0.0.1"
103-
readme = "README.md"
104102
license = "Apache-2.0"
105-
license-files = ["LICENSE"]
106-
keywords = ["llm", "jax", "llama", "mistral", "mixtral", "gemma", "deepseek"]
107103
requires-python = "=={python_version}.*"
108104
dependencies = [
109105
]
110106
classifiers = [
111-
"Development Status :: 4 - Beta",
112107
"Programming Language :: Python",
113108
]
114109
115-
# TODO(kanglant): Remove this once maxtext src-layout restructure done.
116110
[tool.hatch.build.targets.wheel]
117111
packages = ["{project_name}"]
118112

python_seed_env/src/seed_env/uv_utils.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,6 @@ def build_seed_env(
8383
output_dir,
8484
"-r",
8585
seed_lock_file,
86-
"--extra-index-url",
87-
"https://pypi.nvidia.com",
8886
]
8987
run_command(command)
9088

@@ -103,8 +101,6 @@ def build_seed_env(
103101
]
104102
run_command(command)
105103

106-
_remove_hardware_specific_deps(hardware, pyproject_file, output_dir)
107-
108104
command = [
109105
"uv",
110106
"export",
@@ -241,7 +237,10 @@ def replace_dependencies_in_project_toml(new_deps_list: list, filepath: str):
241237
This function reads the specified pyproject.toml file, finds the existing project dependencies array,
242238
and replaces it with the provided new_deps_list list. The updated content is then written back to the file.
243239
"""
244-
new_deps = 'dependencies = [\n "' + '",\n "'.join(new_deps_list) + '",\n]'
240+
if new_deps_list:
241+
new_deps = 'dependencies = [\n "' + '",\n "'.join(new_deps_list) + '",\n]'
242+
else:
243+
new_deps = "dependencies = []"
245244

246245
dependencies_regex = re.compile(
247246
r"^dependencies\s*=\s*\[(\n+\s*.*,\s*)*[\n\r]*\]", re.MULTILINE
@@ -277,6 +276,37 @@ def replace_python_requirement_in_project_toml(min_python: str, filepath: str):
277276
f.write(new_content)
278277

279278

279+
def set_exact_python_requirement_in_project_toml(python_version: str, filepath: str):
    """
    Sets or adds the requires-python entry in a pyproject.toml file to an exact version series.

    If a ``requires-python`` assignment already exists (at the start of a line,
    optionally indented), its value is replaced with ``"=={python_version}.*"``.
    Otherwise the entry is inserted directly below the ``[project]`` header.
    The file is only rewritten when one of those two cases applies.

    Args:
        python_version (str): The target Python version (e.g., '3.12').
        filepath (str): Path to the pyproject.toml file to update.

    Raises:
        ValueError: If the file has neither a ``requires-python`` assignment
            nor a ``[project]`` table to add one to. The file is left untouched.
    """
    # Anchor to the start of a line so a commented-out entry such as
    # `# requires-python = ">=3.10"` is not mistaken for the real one.
    # Both TOML basic ("...") and literal ('...') strings are accepted.
    python_req_regex = re.compile(
        r"""^([ \t]*)requires-python[ \t]*=[ \t]*(?:"[^"\n]*"|'[^'\n]*')""",
        re.MULTILINE,
    )
    project_header_regex = re.compile(r"^\[project\]", re.MULTILINE)
    new_requires_line = f'requires-python = "=={python_version}.*"'

    with open(filepath, "r", encoding="utf-8") as f:
        content = f.read()

    # Function replacements are used so that the replacement text is inserted
    # verbatim (no backslash/group-reference processing by `re`).
    if python_req_regex.search(content):
        # If 'requires-python' exists, substitute it (keeping its indentation).
        new_content = python_req_regex.sub(
            lambda match: match.group(1) + new_requires_line, content
        )
    elif project_header_regex.search(content):
        # If it doesn't exist but [project] does, add it after the [project] header.
        new_content = project_header_regex.sub(
            lambda match: f"[project]\n{new_requires_line}", content, count=1
        )
    else:
        logging.error("No project table found in the template pyproject.toml.")
        # A bare `raise` here has no active exception to re-raise and would
        # surface as an unrelated RuntimeError; raise an explicit error instead.
        raise ValueError(
            f"No [project] table found in the template pyproject.toml: {filepath}"
        )

    with open(filepath, "w", encoding="utf-8") as f:
        f.write(new_content)
308+
309+
280310
def lock_to_lower_bound_project(host_lock_file: str, pyproject_toml: str):
281311
"""
282312
Updates the dependencies in a pyproject.toml file to use lower-bound versions based on a lock file.

python_seed_env/tests/test_core.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,8 @@ def test_environment_seeder_init_invalid_seed():
5757
def test_seed_environment_remote(mocker, tmp_path):
5858
# Mock all external dependencies
5959
mock_download = mocker.patch(
60-
"seed_env.core.download_remote_git_file", return_value=str(tmp_path / "host.txt")
61-
)
62-
mock_generate_pyproject = mocker.patch(
63-
"seed_env.core.generate_minimal_pyproject_toml"
60+
"seed_env.core.download_remote_git_file",
61+
return_value=str(tmp_path / "host.txt"),
6462
)
6563
mock_merge_project_toml_files = mocker.patch("seed_env.core.merge_project_toml_files")
6664
mock_build_env = mocker.patch("seed_env.core.build_seed_env")
@@ -74,6 +72,11 @@ def test_seed_environment_remote(mocker, tmp_path):
7472
)
7573
mocker.patch("seed_env.core.Seeder", return_value=mock_seeder_instance)
7674

75+
# 4. Instantiate and run the seeder.
76+
template_toml_path = tmp_path / "pyproject.toml"
77+
template_toml_path.write_text(
78+
'[project]\nname = "myproj"\nreadme = "README.md"\n[tool.hatch.build.targets.wheel]\npackages = ["myproj"]'
79+
)
7780
seeder = EnvironmentSeeder(
7881
host_name="myproj",
7982
host_source_type="remote",
@@ -86,16 +89,17 @@ def test_seed_environment_remote(mocker, tmp_path):
8689
hardware="cpu",
8790
build_pypi_package=True,
8891
output_dir=str(tmp_path / "output"),
92+
template_pyproject_toml=str(template_toml_path),
8993
)
9094
seeder.seed_environment()
9195

9296
# Assert all mocks were called
9397
assert mock_download.called
94-
assert mock_generate_pyproject.called
98+
# assert mock_generate_pyproject.called
9599
assert mock_build_env.called
96100
assert mock_merge_project_toml_files.called
97101
assert mock_build_pypi.called
98-
assert mock_seeder_instance.download_seed_lock_requirement.called
102+
mock_seeder_instance.download_seed_lock_requirement.assert_called_with("3.12")
99103

100104

101105
def test_seed_environment_local_file_not_found(mocker, tmp_path):

python_seed_env/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def test_build_seed_env_calls_run_command(mocker, tmp_path):
159159
assert mock_os_remove.called
160160
assert mock_lock_to_lower_bound_project.called
161161
# Check for the expected uv remove command
162-
assert mock_remove_hardware_specific_deps.call_count == 2
162+
assert mock_remove_hardware_specific_deps.call_count == 1
163163

164164
# Collect all commands passed to run_command
165165
commands = [call.args[0] for call in mock_run_command.call_args_list]

0 commit comments

Comments
 (0)