Skip to content

Commit 6dc8c68

Browse files
committed
init
1 parent 7dbb864 commit 6dc8c68

File tree

3 files changed

+12
-8
lines changed

3 files changed

+12
-8
lines changed

.github/scripts/m1_script.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
#!/bin/bash
22

3-
export BUILD_VERSION=0.4.0
3+
export TENSORDICT_BUILD_VERSION=0.4.0

.github/workflows/wheels.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ jobs:
3232
run: |
3333
export PATH="/opt/python/${{ matrix.python_version[1] }}/bin:$PATH"
3434
python3 -mpip install wheel
35-
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
35+
TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
3636
# NB: wheels have the linux_x86_64 tag so we rename to manylinux1
3737
# find . -name 'dist/*whl' -exec bash -c ' mv $0 ${0/linux/manylinux1}' {} \;
3838
# pytorch/pytorch binaries are also manylinux_2_17 compliant but they
@@ -72,7 +72,7 @@ jobs:
7272
shell: bash
7373
run: |
7474
python3 -mpip install wheel
75-
BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
75+
TENSORDICT_BUILD_VERSION=0.4.0 python3 setup.py bdist_wheel
7676
- name: Upload wheel for the test-wheel job
7777
uses: actions/upload-artifact@v2
7878
with:

setup.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def parse_args(argv: List[str]) -> argparse.Namespace:
4444

4545
def get_version():
4646
version = (ROOT_DIR / "version.txt").read_text().strip()
47-
if os.getenv("BUILD_VERSION"):
48-
version = os.getenv("BUILD_VERSION")
47+
if os.getenv("TENSORDICT_BUILD_VERSION"):
48+
version = os.getenv("TENSORDICT_BUILD_VERSION")
4949
elif sha != "Unknown":
5050
version += "+" + sha[:7]
5151
return version
@@ -62,11 +62,13 @@ def write_version_file(version):
6262
f.write(f"git_version = {repr(sha)}\n")
6363

6464

65-
def _get_pytorch_version(is_nightly):
65+
def _get_pytorch_version(is_nightly, is_local):
6666
# if "PYTORCH_VERSION" in os.environ:
6767
# return f"torch=={os.environ['PYTORCH_VERSION']}"
6868
if is_nightly:
6969
return "torch>=2.4.0.dev"
70+
if is_local:
71+
return "torch"
7072
return "torch>=2.3.0"
7173

7274

@@ -153,9 +155,11 @@ def _main(argv):
153155

154156
write_version_file(version)
155157
logging.info(f"Building wheel {package_name}-{version}")
156-
logging.info(f"BUILD_VERSION is {os.getenv('BUILD_VERSION')}")
158+
BUILD_VERSION = os.getenv("TENSORDICT_BUILD_VERSION")
159+
logging.info(f"TENSORDICT_BUILD_VERSION is {BUILD_VERSION}")
160+
local_build = BUILD_VERSION is None
157161

158-
pytorch_package_dep = _get_pytorch_version(is_nightly)
162+
pytorch_package_dep = _get_pytorch_version(is_nightly, local_build)
159163
logging.info("-- PyTorch dependency:", pytorch_package_dep)
160164

161165
long_description = (ROOT_DIR / "README.md").read_text()

0 commit comments

Comments
 (0)