Skip to content

Commit 671e9d5

Browse files
authored
Add NFS support for model tests (#857)
* Adds NFS support for model tests. Adds a function that uses cached model inputs, falling back to the URL when nothing is cached. Removes the model validation that checks whether we reference a forked branch, since this would often cause rate-limit errors. * Fixed caching so it also works on machines without NFS.
1 parent d512fa3 commit 671e9d5

File tree

9 files changed

+55
-18
lines changed

9 files changed

+55
-18
lines changed

.github/workflows/before_merge.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,12 @@ jobs:
4545

4646
model-tests:
4747
needs: lowering-tests
48-
runs-on: ["in-service"]
48+
runs-on: ["in-service", "nfs"]
4949
env:
5050
pytest_verbosity: 0
5151
pytest_report_title: "⭐️ Model Tests - Group ${{ matrix.group }}"
52+
TORCH_HOME: /mnt/tt-metal-pytorch-cache/.cache/torch
53+
HF_HOME: /mnt/tt-metal-pytorch-cache/.cache/huggingface
5254
strategy:
5355
matrix: # Need to find a way to replace this with a generator
5456
group: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40]

tests/models/detr/test_detr.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def _load_model(self):
1616
The model is from https://github.com/facebookresearch/detr
1717
"""
1818
# Model
19-
model = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True).to(torch.bfloat16)
19+
model = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True, skip_validation=True).to(
20+
torch.bfloat16
21+
)
2022
return model
2123

2224
def _load_inputs(self):

tests/models/hardnet/test_hardnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
class ThisTester(ModelTester):
1616
def _load_model(self):
17-
model = torch.hub.load("PingoLH/Pytorch-HarDNet", "hardnet68", pretrained=False)
17+
model = torch.hub.load("PingoLH/Pytorch-HarDNet", "hardnet68", pretrained=False, skip_validation=True)
1818
checkpoint = "https://github.com/PingoLH/Pytorch-HarDNet/raw/refs/heads/master/hardnet68.pth"
1919
model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False, map_location="cpu"))
2020
model = model.to(torch.bfloat16)

tests/models/unet/test_unet.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55

66
import numpy as np
77
from PIL import Image
8+
from os import path
89
from torchvision import transforms
910
import requests
1011
import torch
1112
import pytest
12-
from tests.utils import ModelTester
13+
from tests.utils import ModelTester, get_cached_image_or_reload
1314

1415

1516
class ThisTester(ModelTester):
@@ -21,13 +22,17 @@ def _load_model(self):
2122
out_channels=1,
2223
init_features=32,
2324
pretrained=True,
25+
skip_validation=True,
2426
)
2527
model = model.to(torch.bfloat16)
2628
return model
2729

2830
def _load_inputs(self):
29-
url = "https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png"
30-
input_image = Image.open(requests.get(url, stream=True).raw)
31+
image_file = get_cached_image_or_reload(
32+
relative_cache_path="inputs/TCGA_CS_4944.png",
33+
url="https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png",
34+
)
35+
input_image = Image.open(image_file)
3136
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))
3237
preprocess = transforms.Compose(
3338
[

tests/models/unet_brain/test_unet_brain.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from PIL import Image
88
from torchvision import transforms
99
import pytest
10-
from tests.utils import ModelTester
10+
from tests.utils import ModelTester, get_cached_image_or_reload
1111

1212

1313
class ThisTester(ModelTester):
@@ -23,19 +23,16 @@ def _load_model(self):
2323
out_channels=1,
2424
init_features=32,
2525
pretrained=True,
26+
skip_validation=True,
2627
)
2728
model = model.to(torch.bfloat16)
2829
return model
2930

3031
def _load_inputs(self):
31-
url, filename = (
32-
"https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png",
33-
"TCGA_CS_4944.png",
32+
filename = get_cached_image_or_reload(
33+
relative_cache_path="inputs/TCGA_CS_4944.png",
34+
url="https://github.com/mateuszbuda/brain-segmentation-pytorch/raw/master/assets/TCGA_CS_4944.png",
3435
)
35-
try:
36-
urllib.URLopener().retrieve(url, filename)
37-
except:
38-
urllib.request.urlretrieve(url, filename)
3936

4037
input_image = Image.open(filename)
4138
m, s = np.mean(input_image, axis=(0, 1)), np.std(input_image, axis=(0, 1))

tests/models/yolov5/test_yolov5.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
class ThisTester(ModelTester):
1818
def _load_model(self):
1919
# Model
20-
model = torch.hub.load("ultralytics/yolov5", "yolov5s", pretrained=True, autoshape=False, device="cpu")
20+
model = torch.hub.load(
21+
"ultralytics/yolov5", "yolov5s", pretrained=True, autoshape=False, device="cpu", skip_validation=True
22+
)
2123
return model.to(torch.bfloat16)
2224

2325
def _load_inputs(self):

tests/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import torch
55
import numpy as np
66
import re
7+
import requests
8+
from os import path, makedirs
79
from collections.abc import Mapping, Sequence
810
from typing import List, Dict, Tuple
911

@@ -147,6 +149,33 @@ def test_model(self, as_ttnn=False, option=None):
147149
raise ValueError(f"Current mode is not supported: {self.mode}")
148150

149151

152+
def get_absolute_cache_path(path_relative_to_cache):
    """Resolve a cache-relative path to an absolute location.

    The shared NFS cache directory is used as the base when it exists on
    this machine; otherwise the base is the current user's ``~/.cache``.

    Args:
        path_relative_to_cache: Path of the entry relative to the cache base.

    Returns:
        The cache base joined with ``path_relative_to_cache``.
    """
    shared_root = "/mnt/tt-metal-pytorch-cache/.cache"
    # Prefer the NFS mount when present; fall back to the per-user cache.
    cache_root = shared_root if path.exists(shared_root) else path.expanduser("~/.cache")
    return path.join(cache_root, path_relative_to_cache)
160+
161+
162+
def get_cached_image_or_reload(relative_cache_path, url):
    """Return a local path to the file at ``url``, downloading it on a cache miss.

    The cache location is resolved with ``get_absolute_cache_path``. If a file
    already exists there it is returned as-is; otherwise the file is fetched
    from ``url`` and stored in the cache.

    Args:
        relative_cache_path: Location of the file relative to the cache base.
        url: Source to download from when the file is not cached yet.

    Returns:
        Absolute path of the cached file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.RequestException: On connection failures or timeouts.
        OSError: If the cache directory or file cannot be written.
    """
    # Function-scope imports: the file only imports `path` and `makedirs` from os.
    import os
    import uuid

    absolute_cache_path = get_absolute_cache_path(relative_cache_path)

    if path.exists(absolute_cache_path):
        return absolute_cache_path

    cache_dir = path.dirname(absolute_cache_path)
    makedirs(cache_dir, exist_ok=True)

    # Download to a uniquely named sibling file and rename it into place, so an
    # interrupted or concurrent download can never be seen as a valid cache
    # entry by this or any other process sharing the cache.
    partial_path = f"{absolute_cache_path}.{uuid.uuid4().hex}.part"
    try:
        # The timeout keeps an unresponsive server from hanging the test run.
        with requests.get(url, stream=True, timeout=60) as response:
            # Without this check an HTTP error page would be cached as the image.
            response.raise_for_status()
            with open(partial_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=8192):
                    file.write(chunk)
        os.replace(partial_path, absolute_cache_path)
    finally:
        # After a successful rename the partial file no longer exists; this
        # only removes the leftovers of a failed download.
        if path.exists(partial_path):
            os.remove(partial_path)

    return absolute_cache_path
177+
178+
150179
# Testing utils copied from tt-metal/tests/ttnn/utils_for_testing.py
151180
def comp_pcc(golden, calculated, pcc=0.99):
152181
golden = torch.Tensor(golden)

tools/run_torchvision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def run_model(
1717
device=None,
1818
):
1919
if model_name == "dinov2_vits14":
20-
m = torch.hub.load("facebookresearch/dinov2", model_name)
20+
m = torch.hub.load("facebookresearch/dinov2", model_name, skip_validation=True)
2121
else:
2222
try:
2323
m = torchvision.models.get_model(model_name, pretrained=True)

tools/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77

88
def get_model(model_name):
99
if model_name == "dinov2_vits14":
10-
m = torch.hub.load("facebookresearch/dinov2", model_name)
10+
m = torch.hub.load("facebookresearch/dinov2", model_name, skip_validation=True)
1111
elif model_name == "detr_resnet50":
12-
m = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True)
12+
m = torch.hub.load("facebookresearch/detr:main", "detr_resnet50", pretrained=True, skip_validation=True)
1313
else:
1414
try:
1515
m = torchvision.models.get_model(model_name, pretrained=True)

0 commit comments

Comments
 (0)