Skip to content

Commit 63606a8

Browse files
Initial split up
1 parent d4ad395 commit 63606a8

File tree

6 files changed

+79
-78
lines changed

6 files changed

+79
-78
lines changed

env/core_requirements.txt

Lines changed: 19 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,27 @@
1+
#core requirements
12
setuptools>=61
2-
# This is needed to avoid issue https://yyz-gitlab.local.tenstorrent.com/devops/devops/-/issues/95
3-
# jax requires any version of optax which requires any version of chex which in turn
4-
# requires jax>=0.4.6 which conflicts with our jax == 0.3.16
5-
# TODO: Remove when jax library is upgraded
6-
chex==0.1.6
7-
dataclasses-json==0.5.7
8-
datasets==2.14.6
3+
pybind11==2.6.2
4+
numpy==1.26.4
5+
scipy>=1.8.0
6+
pandas==1.5.3
97
decorator==5.1.1
10-
flatbuffers==23.5.26
11-
# This is needed to prevent AttributeError: module 'ml_dtypes' has no attribute 'float8_e4m3b11'
12-
flax==0.9.0
13-
jax==0.4.30
148
loguru==0.5.3
15-
networkx==2.8.5
16-
numpy==1.26.4
9+
flatbuffers==23.5.26
10+
openpyxl==3.1.5
11+
GitPython==3.1.44
12+
pyinstrument>=4.1.1
13+
14+
#framework requirements
1715
onnx>=1.15.0
1816
onnxruntime>=1.16.3
19-
opencv-python-headless==4.11.0.86
20-
# This is needed to avoid issue https://yyz-gitlab.local.tenstorrent.com/devops/devops/-/issues/95
21-
pandas==1.5.3
22-
pybind11==2.6.2
23-
pyinstrument>=4.1.1
24-
scipy>=1.8.0
17+
tf2onnx==1.15.1
2518
tensorflow==2.15
2619
tensorboard==2.15
27-
tf2onnx==1.15.1
28-
transformers==4.47.0
29-
# To avoid warning during the import
30-
requests==2.28.2
3120
tflite==2.10.0
32-
ultralytics==8.3.91
33-
paddlepaddle==2.6.2
34-
paddlenlp==2.8.1
35-
aistudio-sdk==0.2.6
36-
pytorch_forecasting==1.0.0
37-
patool
38-
openpyxl==3.1.5
39-
GitPython==3.1.44
40-
mlp-mixer-pytorch==0.2.0
41-
gliner==0.2.7
42-
ase==3.24.0
43-
hippynn==0.0.3
44-
bi-lstm-crf==0.2.1
45-
peft
46-
pyclipper==1.3.0
47-
shapely==2.1.1
21+
torch @ https://download.pytorch.org/whl/cpu/torch-2.7.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version == "3.10"
22+
torchvision @ https://download.pytorch.org/whl/cpu/torchvision-0.22.0%2Bcpu-cp310-cp310-manylinux_2_28_x86_64.whl ; python_version == "3.10"
23+
24+
# JAX stack — note version constraints from chex
25+
jax==0.4.30
26+
chex==0.1.6
27+
flax==0.9.0 # Needs compatibility with jax

env/create_venv.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ if [[ "$OSTYPE" == "darwin"* ]]; then
2828
REQUIREMENTS_FILE="$CURRENT_SOURCE_DIR/mac_requirements.txt"
2929
else
3030
# TODO test on linux
31-
REQUIREMENTS_FILE="$CURRENT_SOURCE_DIR/linux_requirements.txt"
31+
REQUIREMENTS_FILE="$CURRENT_SOURCE_DIR/dev_requirements.txt"
3232
fi
3333

3434
$TTFORGE_PYTHON_VERSION -m venv $TTFORGE_VENV_DIR

env/dev_requirements.txt

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# First include all requirements from the Distribution build
2+
-r core_requirements.txt
3+
#pytest based
4+
requests==2.28.2
5+
6+
#model based
7+
transformers==4.47.0
8+
datasets==2.14.6
9+
dataclasses-json==0.5.7
10+
opencv-python-headless==4.11.0.86
11+
ultralytics==8.3.91
12+
13+
# Paddle stack
14+
paddlepaddle==2.6.2
15+
paddlenlp==2.8.1
16+
aistudio-sdk==0.2.6
17+
18+
# PyTorch ecosystem
19+
pytorch_forecasting==1.0.0
20+
21+
# Misc modeling libs
22+
mlp-mixer-pytorch==0.2.0
23+
gliner==0.2.7
24+
ase==3.24.0
25+
hippynn==0.0.3
26+
bi-lstm-crf==0.2.1
27+
peft
28+
29+
# Geometry/segmentation utils
30+
pyclipper==1.3.0
31+
shapely==2.1.1
32+
33+
clang-format==14.0.3
34+
diffusers==0.32.1
35+
pytest==6.2.4
36+
pytest-timeout==2.0.1
37+
pytest-xdist==2.5.0
38+
pytorchcv==0.0.67
39+
pytest-split
40+
seaborn
41+
scikit-image==0.20.0 # For DenseNet 121 HF XRay model
42+
segmentation_models_pytorch==0.4.0
43+
timm==1.0.9
44+
torchxrayvision==0.0.39
45+
vgg_pytorch==0.3.0
46+
python-gitlab==4.4.0
47+
tabulate==0.9.0
48+
yolov6detect==0.4.1
49+
peft==0.15.1

env/dist_requirements.txt

Lines changed: 0 additions & 4 deletions
This file was deleted.

env/linux_requirements.txt

Lines changed: 0 additions & 26 deletions
This file was deleted.

setup.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def build_forge(self, ext):
6262
with open("env/core_requirements.txt", "r") as f:
6363
core_requirements = f.read().splitlines()
6464

65-
with open("env/linux_requirements.txt", "r") as f:
66-
linux_requirements = [r for r in f.read().splitlines() if not r.startswith("-r")]
65+
with open("env/dev-requirements.txt", "r") as f:
66+
dev_requirements = [r for r in f.read().splitlines() if not r.startswith("-r")]
6767

6868

6969
def collect_model_requirements(requirements_root: str) -> list[str]:
@@ -152,10 +152,9 @@ def collect_model_requirements(requirements_root: str) -> list[str]:
152152
return [pkg + ver if ver else pkg for pkg, ver in sorted(final_requirements.items())]
153153

154154

155-
model_requirements_root = "forge/test/models"
156-
model_requirements = collect_model_requirements(model_requirements_root)
155+
model_requirements = collect_model_requirements("forge/test/models")
157156

158-
requirements = core_requirements + linux_requirements + model_requirements
157+
requirements = core_requirements + dev_requirements + model_requirements
159158

160159
# Compute a dynamic version from git
161160
short_hash = subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]).decode("ascii").strip()
@@ -171,14 +170,17 @@ def collect_model_requirements(requirements_root: str) -> list[str]:
171170
# Find packages as before
172171
packages = [p for p in find_packages("forge") if not p.startswith("test")]
173172

174-
175173
setup(
176174
name="tt_forge_fe",
177175
version=version,
178-
install_requires=requirements,
176+
install_requires=core_requirements,
177+
extras_require={
178+
"dev": dev_requirements,
179+
"model": dev_requirements + model_requirements,
180+
},
179181
packages=packages,
180182
package_dir={"forge": "forge/forge"},
181-
ext_modules=[forge_c],
183+
ext_modules=[TTExtension("forge")],
182184
cmdclass={"build_ext": CMakeBuild},
183185
long_description=long_description,
184186
long_description_content_type="text/markdown",

0 commit comments

Comments
 (0)