Unify model and input loading via the tt-forge-models repository #3183

@chandrasekaranpradeep

Description

Summary

Standardize all model and input loading in tt-forge-onnx to use the tt-forge-models repository (e.g. third_party/tt_forge_models), remove PyTorch and JAX model tests from forge/test/models, and align ONNX, PaddlePaddle, and TensorFlow tests with a single source of truth for models and inputs.

Context

  • tt-forge-onnx focuses on frameworks not covered by tt-xla: ONNX, PaddlePaddle, and TensorFlow.
  • tt-forge-models is the shared model repository used across TT-Forge frontend projects. It provides a consistent ModelLoader interface (plus related utilities) for loading models and inputs.
  • Currently, loading strategies are inconsistent:
    • ONNX: Some tests already use tt-forge-models (e.g. regnet, dla, googlenet, fuyu, clip, roberta, centernet); many others load PyTorch models inside the test and pull weights and inputs from local utils, external URLs, or checkpoints.
    • PaddlePaddle: Models and inputs are loaded inside the test suite via framework APIs (e.g. resnet18(pretrained=True), paddle.rand(...)) and local utils, not from tt-forge-models.
    • TensorFlow: Models and inputs are loaded in-test through framework APIs (e.g. Keras ResNet50(weights="imagenet")) and local model_utils helpers (e.g. get_sample_inputs()), not from tt-forge-models.
  • forge/test/models still contains PyTorch and JAX model tests, which are out of scope for tt-forge-onnx (PyTorch/JAX are covered by tt-xla).
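To make the target pattern concrete, here is the shape a test takes once loading is delegated: it asks a loader for the model and the inputs and never builds either itself. This is a minimal sketch assuming a loader that exposes `load_model()` and `load_inputs()`, loosely mirroring the ModelLoader interface mentioned above. `LoaderProtocol`, `FakeLoader`, and `run_with_loader` are hypothetical names for illustration only, and the real tt-forge-models signatures (variant selection, dtype overrides, and so on) may differ.

```python
from typing import Any, Callable, Protocol


class LoaderProtocol(Protocol):
    """Hypothetical stand-in for a tt-forge-models ModelLoader."""

    def load_model(self) -> Any: ...

    def load_inputs(self) -> Any: ...


def run_with_loader(loader: LoaderProtocol, run: Callable[[Any, Any], Any]) -> Any:
    """Fetch model and inputs from the loader (the single source of truth)
    and hand both to a framework-specific `run` callable."""
    model = loader.load_model()
    inputs = loader.load_inputs()
    return run(model, inputs)


class FakeLoader:
    """Toy loader for illustration only: the 'model' doubles each input value."""

    def load_model(self) -> Callable[[list], list]:
        return lambda xs: [2.0 * x for x in xs]

    def load_inputs(self) -> list:
        return [1.0, 2.0, 3.0]


if __name__ == "__main__":
    result = run_with_loader(FakeLoader(), lambda model, inputs: model(inputs))
    print(result)  # [2.0, 4.0, 6.0]
```

Because the test body depends only on the loader, the same structure would apply to ONNX, PaddlePaddle, and TensorFlow tests; only the `run` callable is framework-specific.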

Goals

  1. ONNX

    • Load all models and inputs from tt-forge-models (via its loaders), instead of:
      • Loading PyTorch models in-test.
      • Downloading .pt/.pth from external URLs.
      • Using local utils (e.g. test.models.pytorch.vision.mnist.model_utils.utils) or repo-local files (e.g. forge/test/models/files/samples/audio/).
        Note: tt-forge-models does not yet provide loaders for the ONNX framework, so for now load PyTorch models from tt-forge-models, convert them to ONNX using the torch.onnx.export API, and then compile and execute the resulting ONNX models.
  2. PaddlePaddle

    • Load models and inputs from tt-forge-models (using its loaders/utils) instead of loading models and constructing inputs entirely within the test suite (e.g. eval(variant)(pretrained=True), paddle.rand(...), or local utils).
  3. TensorFlow

    • Load models and inputs from tt-forge-models instead of using framework APIs and local model_utils (e.g. Keras ResNet50(weights="imagenet") and get_sample_inputs() in test.models.tensorflow.vision.resnet.model_utils.image_utils).
  4. Cleanup

    • Remove PyTorch model tests under forge/test/models/pytorch/.
    • Remove JAX model tests under forge/test/models/jax/.
    • Rely on tt-forge-models (and, where applicable, ONNX tests that use it) as the single source for model definitions, weights, and inputs for the frameworks in scope.
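The interim ONNX flow from the first goal (load a PyTorch model from tt-forge-models, export it with torch.onnx.export, then compile and execute it) can be sketched as a three-stage pipeline. This is a minimal sketch, not the project's implementation: the stages are passed in as callables so the example runs without torch, onnx, or forge installed, and `onnx_test_flow`, `ToyLoader`, `toy_export`, and `toy_compile_and_run` are hypothetical names. In a real test, the export stage would wrap something like `torch.onnx.export(model, (inputs,), path)` followed by `onnx.load(path)`, and the compile stage would call the forge compile entry point; those wrappers are omitted because their exact arguments are assumptions here.

```python
from typing import Any, Callable


def onnx_test_flow(
    loader: Any,
    export_fn: Callable[[Any, Any], Any],
    compile_and_run_fn: Callable[[Any, Any], Any],
) -> Any:
    """Hypothetical ONNX test flow: PyTorch model from a loader -> ONNX -> compile/run."""
    # 1. Load the PyTorch model and its sample inputs from the shared loader.
    torch_model = loader.load_model()
    inputs = loader.load_inputs()
    # 2. Convert to ONNX (in practice: torch.onnx.export, then onnx.load).
    onnx_model = export_fn(torch_model, inputs)
    # 3. Compile the ONNX model and execute it on the same inputs.
    return compile_and_run_fn(onnx_model, inputs)


# Toy stand-ins so the sketch runs end to end without torch/onnx/forge.
class ToyLoader:
    def load_model(self) -> str:
        return "torch:toy_model"

    def load_inputs(self) -> list:
        return [0.5, 1.5]


def toy_export(model: str, inputs: list) -> str:
    # Pretend export: relabel the PyTorch model as an ONNX model.
    return model.replace("torch:", "onnx:")


def toy_compile_and_run(onnx_model: str, inputs: list) -> tuple:
    # Pretend execution: report which model ran and on how many inputs.
    return (onnx_model, len(inputs))


if __name__ == "__main__":
    print(onnx_test_flow(ToyLoader(), toy_export, toy_compile_and_run))
    # ('onnx:toy_model', 2)
```

Keeping the stages separate means the test body stays identical across models while only the loader changes; if ONNX loaders are added to tt-forge-models later, the export stage could be dropped in favor of a loader that returns the ONNX model directly.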

cc @nvukobratTT

Metadata

Labels

No labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
