Summary
Standardize all model and input loading in tt-forge-onnx to use the tt-forge-models repository (e.g. third_party/tt_forge_models), remove PyTorch and JAX model tests from forge/test/models, and align ONNX, PaddlePaddle, and TensorFlow tests with a single source of truth for models and inputs.
Context
- tt-forge-onnx focuses on frameworks not covered by tt-xla: ONNX, PaddlePaddle, and TensorFlow.
- tt-forge-models is the shared model repository used across TT-Forge frontend projects and provides a consistent ModelLoader (and related) interface for loading models and inputs.
- Currently, loading strategies are inconsistent (see the illustrative snippet after this list):
  - ONNX: Some tests use tt-forge-models (e.g. regnet, dla, googlenet, fuyu, clip, roberta, centernet); many others load PyTorch models in-test and use local utils or URLs/checkpoints for weights and inputs.
  - PaddlePaddle: Models and inputs are loaded inside the test suite via framework APIs (e.g. resnet18(pretrained=True), paddle.rand(...)) and local utils, not from tt-forge-models.
  - TensorFlow: Models and inputs are loaded in-test (e.g. Keras ResNet50(weights="imagenet")) and via local model_utils (e.g. get_sample_inputs()), not from tt-forge-models.
- forge/test/models still contains PyTorch and JAX model tests, which are out of scope for tt-forge-onnx (PyTorch/JAX are covered by tt-xla).
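For reference, the in-test loading pattern described above looks roughly like the following. This is a simplified sketch, not code copied from any specific test: shapes and variants are illustrative, and the local helpers (e.g. get_sample_inputs()) are approximated here with random tensors.

```python
# Simplified sketch of the current in-test loading pattern that this issue
# proposes to replace with tt-forge-models loaders. Shapes and variants are
# illustrative only.

import paddle
import tensorflow as tf

# PaddlePaddle: model and inputs constructed directly inside the test.
paddle_model = paddle.vision.models.resnet18(pretrained=True)
paddle_inputs = [paddle.rand([1, 3, 224, 224])]

# TensorFlow: Keras application loaded in-test; inputs normally come from local
# model_utils helpers (e.g. get_sample_inputs()), approximated with a random tensor.
tf_model = tf.keras.applications.ResNet50(weights="imagenet")
tf_inputs = [tf.random.uniform((1, 224, 224, 3))]
```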
Goals
- ONNX
  - Load all models and inputs from tt-forge-models (via its loaders), instead of:
    - Loading PyTorch models in-test.
    - Downloading .pt/.pth files from external URLs.
    - Using local utils (e.g. test.models.pytorch.vision.mnist.model_utils.utils) or repo-local files (e.g. forge/test/models/files/samples/audio/).
  - Note: We do not currently have loaders for the ONNX framework, so for now load PyTorch models from the tt-forge-models repository, convert them to ONNX using the torch.onnx.export API, and then compile and execute them (see the sketch after this list).
- PaddlePaddle
  - Load models and inputs from tt-forge-models (using its loaders/utils) instead of loading models and constructing inputs entirely within the test suite (e.g. eval(variant)(pretrained=True), paddle.rand(...), or local utils).
- TensorFlow
  - Load models and inputs from tt-forge-models instead of using framework APIs and local model_utils (e.g. Keras ResNet50(weights="imagenet") and get_sample_inputs() in test.models.tensorflow.vision.resnet.model_utils.image_utils).
- Cleanup
  - Remove PyTorch model tests under forge/test/models/pytorch/.
  - Remove JAX model tests under forge/test/models/jax/.
  - Rely on tt-forge-models (and, where applicable, ONNX tests that use it) as the single source for model definitions, weights, and inputs for the frameworks in scope.
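A minimal sketch of the interim ONNX flow from the note above, assuming the tt-forge-models loaders expose load_model()/load_inputs() methods and that the frontend keeps a forge.compile(...) entry point compatible with tt-forge-fe. The import path and model name below are placeholders, not a real loader location.

```python
# Sketch of the interim ONNX flow: PyTorch model + inputs from tt-forge-models,
# exported to ONNX, then compiled and executed. Loader import path, method
# names, and the forge.compile()/compiled-model calls are assumptions.

import onnx
import torch
import forge  # tt-forge-onnx frontend; compile API assumed to mirror tt-forge-fe

# Hypothetical loader location; real loaders live under third_party/tt_forge_models.
from third_party.tt_forge_models.resnet.pytorch.loader import ModelLoader

loader = ModelLoader()
torch_model = loader.load_model().eval()  # assumed loader API
inputs = loader.load_inputs()             # assumed to return a list of torch tensors

# No ONNX-native loaders exist yet, so convert the PyTorch model to ONNX.
onnx_path = "model.onnx"
torch.onnx.export(torch_model, tuple(inputs), onnx_path, opset_version=17)

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

# Compile the exported graph and run it on the loaded inputs.
compiled_model = forge.compile(onnx_model, sample_inputs=inputs)
outputs = compiled_model(*inputs)
```

The PaddlePaddle and TensorFlow goals would follow the same loader-based pattern, minus the torch.onnx.export step: load the framework model and inputs from tt-forge-models and pass them directly to the compile call, replacing the in-test eval(variant)(pretrained=True) / ResNet50(weights="imagenet") construction.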
cc @nvukobratTT