Commit ecc3ca4

Authored by alnah005, aperezvicente-TT, github-code-quality[bot], and Copilot
Add tt_symbiote: PyTorch-to-TTNN transparent acceleration framework (#35699)
### Ticket
N/A - New framework introduction

### Problem description

PyTorch models running on Tenstorrent hardware require manual conversion of operations to TTNN, creating friction for ML engineers. There's no transparent way to:

- Automatically replace PyTorch modules with TTNN equivalents
- Handle weight preprocessing and device management
- Provide fallback mechanisms when TTNN ops aren't supported
- Debug and validate TTNN vs PyTorch execution paths
- Selectively accelerate specific model layers

This prevents rapid prototyping and requires deep TTNN knowledge for every model port.

### What's changed

Introducing [tt_symbiote](https://github.com/tenstorrent/tt-metal/tree/alnah005/tt_symbiote/models/tt_symbiote) - a PyTorch-to-TTNN acceleration framework that transparently replaces torch modules with TTNN implementations while handling device management, weight lifecycle, and automatic fallback.

#### Core Architecture

```
core/
├── module.py         # TTNNModule base class with auto-fallback
├── tensor.py         # TorchTTNNTensor wrapper implementing __torch_dispatch__
├── dispatchers/      # Operation dispatch layer
│   ├── default_dispatcher.py
│   ├── debug_dispatcher.py
│   └── cpu_dispatcher.py
├── run_config.py     # 6 execution modes (NORMAL, DPL, SEL, etc.)
└── utils.py          # dtype/layout conversions

modules/              # 20+ TTNN module implementations
├── linear.py         # TTNNLinear + LLaMA variants (bfloat8/16)
├── attention.py      # SDPA, fused QKV, Whisper attention
├── conv.py           # Conv2d, MaxPool, Upsample, Bottleneck
├── activation.py     # Silu, ReLU, Gelu
├── normalization.py  # LayerNorm
└── tensor.py         # Permute, Reshape, Add

utils/
├── module_replacement.py  # Recursive module substitution with exclusion
└── device_management.py   # Device lifecycle and weight movement

tests/                # 14 model validation tests
```

#### Key Features

1. **Transparent Module Replacement**: `register_module_replacement_dict(model, {nn.Linear: TTNNLinear}, exclude_replacement=set(["layer1.0"]))`
2. **Automatic Dispatch**: `TorchTTNNTensor` wraps tensors and routes ops through `__torch_dispatch__` to the TTNN backend
3. **Six Run Modes via `TT_SYMBIOTE_RUN_MODE`**:
   - `NORMAL`: Pure TTNN execution
   - `NORMAL_WITH_FALLBACK`: Auto-fallback on errors
   - `SEL`: Segment Each Layer - PyTorch takes TTNN tensors, validates with PCC
   - `DPL`: Debug Per Layer - runs both paths separately, compares
   - `DPL_NO_ERROR_PROP`: DPL with PyTorch tensors to avoid error propagation
   - `CPU`: CPU-only mode
4. **Weight Lifecycle Management**: Explicit preprocess → host → device → deallocate with the `@deallocate_weights_after` decorator
5. **Three Dispatchers via `TT_SYMBIOTE_DISPATCHER`**: DEFAULT, DEBUG (verbose logging), CPU

#### Validated Models (14 tests)

- Vision: ViT, ResNet50, YUNet, OWL-ViT
- LLMs: LLaMA-3.2-1B, GPT-OSS-20B, GLM-4.5-Air
- Multi-modal: OpenVLA-7B
- Speech: Whisper-large-v3, SpeechT5
- Video: HunyuanVideo 1.5

#### API Example

```python
from torch import nn
from transformers import AutoModelForImageClassification
from models.tt_symbiote.utils.module_replacement import register_module_replacement_dict
from models.tt_symbiote.modules.linear import TTNNLinear

model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")
register_module_replacement_dict(model, {nn.Linear: TTNNLinear})
```

## Compile-time impact

Python-only framework, no C++ changes. Zero impact on tt_metal/ttnn build times.

## ABI/API stability

No changes to public tt_metal or ttnn APIs. Pure extension layer.

---------

Co-authored-by: Alejandro Perez-Vicente <aperezvicente@tenstorrent.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 694c6ef commit ecc3ca4

38 files changed, +7351 −0 lines changed

.github/CODEOWNERS

Lines changed: 1 addition & 0 deletions
```diff
@@ -366,6 +366,7 @@ models/experimental/SSD512 @dvartaniansTT @tenstorrent/cse-developer-ttnn @tenst
 models/experimental/smolvla @tvardhineniTT @dvartaniansTT @tenstorrent/cse-developer-ttnn @tenstorrent/codeowner-bypass
 models/experimental/openvla @tvardhineniTT @dvartaniansTT @tenstorrent/cse-developer-ttnn @tenstorrent/codeowner-bypass
 models/experimental/tt_transformers_v2 @gwangTT @yieldthought @uaydonat @tenstorrent/codeowner-bypass
+models/experimental/tt_symbiote @alnah005 @mbahnasTT @yieldthought @uaydonat @tenstorrent/codeowner-bypass
 models/experimental/retinanet @dvartaniansTT @tenstorrent/cse-developer-ttnn @tenstorrent/codeowner-bypass
 models/experimental/efficientdetd0 @dvartaniansTT @tenstorrent/cse-developer-ttnn @tenstorrent/codeowner-bypass
 models/**/requirements*.txt @tenstorrent/metalium-developers-infra @tenstorrent/codeowner-bypass
```

models/experimental/tt_symbiote/ARCHITECTURE.svg

Lines changed: 1 addition & 0 deletions
Lines changed: 359 additions & 0 deletions
# TT-Symbiote

PyTorch-to-TTNN acceleration framework for transparent hardware acceleration of neural networks on Tenstorrent devices.

## Overview

TT-Symbiote enables TTNN acceleration of pretrained PyTorch models by replacing standard PyTorch modules (e.g., `nn.Linear`, `nn.LayerNorm`) with TTNN-optimized equivalents. The framework automatically handles:
- Module replacement and weight conversion
- Device management and memory allocation
- Fallback to PyTorch when TTNN operations fail

## Flow Diagram

![Architecture diagram](ARCHITECTURE.svg)

## Run Modes

TT-Symbiote supports multiple execution modes via the `TT_SYMBIOTE_RUN_MODE` environment variable:

- **NORMAL** - Default TTNN execution mode
- **NORMAL_WITH_FALLBACK** - TTNN with automatic PyTorch fallback on errors
- **SEL** - Segment Each Layer: PyTorch takes TTNN tensors as input; outputs are compared with PCC.
- **DPL** - Debug Per Layer: runs both TTNN and PyTorch separately and compares outputs with PCC.
- **DPL_NO_ERROR_PROP** - DPL, but TTNN takes PyTorch tensors as input to avoid error propagation.
- **CPU** - CPU-only execution mode

```bash
# Set execution mode before running
export TT_SYMBIOTE_RUN_MODE=NORMAL && pytest tests/test_resnet50.py

# Or use DPL mode for debugging
export TT_SYMBIOTE_RUN_MODE=DPL && pytest tests/test_vit.py
```

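The SEL and DPL modes validate layer outputs with PCC (the Pearson correlation coefficient) between the TTNN and PyTorch results. As a rough illustration of what such a check computes, here is a hypothetical stdlib-only sketch (the `pcc` helper and the 0.99 threshold are illustrative, not the framework's actual comparison utilities):

```python
import math

def pcc(a, b):
    """Pearson correlation coefficient between two flat lists of floats."""
    n = len(a)
    mean_a = sum(a) / n
    mean_b = sum(b) / n
    cov = sum((x - mean_a) * (y - mean_b) for x, y in zip(a, b))
    norm_a = math.sqrt(sum((x - mean_a) ** 2 for x in a))
    norm_b = math.sqrt(sum((y - mean_b) ** 2 for y in b))
    return cov / (norm_a * norm_b)

# A DPL-style check: flag a layer when the TTNN output drifts from the
# PyTorch output enough to drop the correlation below a threshold.
torch_out = [0.1, 0.5, 0.9, 1.3]
ttnn_out = [0.1001, 0.4999, 0.9002, 1.2998]
layer_ok = pcc(torch_out, ttnn_out) > 0.99
```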
## Dispatcher Configuration

TT-Symbiote supports multiple dispatcher implementations via the `TT_SYMBIOTE_DISPATCHER` environment variable:

- **DEFAULT** - Standard TTNN operation dispatcher (despite the name, it is not the default; `CPU` is used when the variable is unset)
- **DEBUG** - Verbose logging dispatcher for debugging
- **CPU** - CPU-only dispatcher for testing

```bash
# Use the standard TTNN dispatcher
export TT_SYMBIOTE_DISPATCHER=DEFAULT

# Use debug dispatcher for verbose operation logging
export TT_SYMBIOTE_DISPATCHER=DEBUG

# Use CPU dispatcher
export TT_SYMBIOTE_DISPATCHER=CPU

# Combine with run mode for advanced debugging
export TT_SYMBIOTE_DISPATCHER=CPU && export TT_SYMBIOTE_RUN_MODE=DPL_NO_ERROR_PROP && pytest tests/test_speech_t5.py
```

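This kind of selection is commonly implemented as an environment-variable lookup into a dispatcher registry with a safe default. A hypothetical stdlib-only sketch, assuming CPU is the fallback when the variable is unset (the function and registry names here are illustrative, not the framework's actual internals in `dispatchers/dispatcher_config.py`):

```python
import os

# Illustrative stand-ins for the three dispatcher implementations.
def default_dispatch(op, *args):
    return ("ttnn", op)

def debug_dispatch(op, *args):
    print(f"[debug] dispatching {op}")
    return ("ttnn", op)

def cpu_dispatch(op, *args):
    return ("cpu", op)

DISPATCHERS = {
    "DEFAULT": default_dispatch,
    "DEBUG": debug_dispatch,
    "CPU": cpu_dispatch,
}

def get_dispatcher():
    # CPU is assumed to be used when TT_SYMBIOTE_DISPATCHER is unset
    # or unrecognized.
    name = os.environ.get("TT_SYMBIOTE_DISPATCHER", "CPU")
    return DISPATCHERS.get(name, DISPATCHERS["CPU"])

os.environ["TT_SYMBIOTE_DISPATCHER"] = "DEBUG"
dispatcher = get_dispatcher()
```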
## Quick Start

```python
import torch
import ttnn
from torch import nn
from transformers import AutoModelForImageClassification
from models.experimental.tt_symbiote.utils.module_replacement import register_module_replacement_dict
from models.experimental.tt_symbiote.utils.device_management import set_device
from models.experimental.tt_symbiote.modules.linear import TTNNLinear
from models.experimental.tt_symbiote.modules.normalization import TTNNLayerNorm

# Load model
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Define module replacement mapping
nn_to_ttnn = {
    nn.Linear: TTNNLinear,
    nn.LayerNorm: TTNNLayerNorm,
}

# Replace modules and set device
register_module_replacement_dict(model, nn_to_ttnn, model_config=None)

# Obtain a TTNN device (e.g., through a pytest fixture or ttnn.CreateDevice)
ttnn_device = ttnn.CreateDevice(0)

set_device(model, ttnn_device)

# Run inference
model.eval()
torch.set_grad_enabled(False)
result = model(torch.randn(1, 3, 224, 224))
```

## Selective Module Replacement

You can selectively exclude specific modules from replacement using the `exclude_replacement` parameter:

```python
# Replace all Bottleneck modules except layer1.0
register_module_replacement_dict(
    model,
    nn_to_ttnn,
    model_config={"program_config_ffn": {}},
    exclude_replacement=set(["layer1.0"])
)
```

**How it works:**

1. **Initial replacement** - First, run without exclusions to see the module names:
   ```python
   register_module_replacement_dict(model, nn_to_ttnn, exclude_replacement=set([]))
   ```

2. **Identify module names** - Check the model structure. TTNN modules show their `module_name`:
   ```python
   # layer1.0 is now a TTNNBottleneck with module_name=layer1.0
   (0): TTNNBottleneck(module_name=layer1.0
     (conv1): TTNNConv2dBNActivationNHWC(...)
     (conv2): TTNNConv2dBNActivationNHWC(...)
     ...
   )
   (1): TTNNBottleneck(module_name=layer1.1
     (conv1): TTNNConv2dBNActivationNHWC(...)
     ...
   )
   ```

3. **Re-run with exclusions** - Use the module names to exclude specific modules:
   ```python
   # Exclude layer1.0 from replacement - it stays as PyTorch Bottleneck
   register_module_replacement_dict(
       model,
       nn_to_ttnn,
       exclude_replacement=set(["layer1.0"])
   )
   ```

**Result:**
```python
# layer1.0 remains as original PyTorch Bottleneck
(0): Bottleneck(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), ...)
  (bn1): BatchNorm2d(64, ...)
  ...
)
# layer1.1 is replaced with TTNN
(1): TTNNBottleneck(module_name=layer1.1
  (conv1): TTNNConv2dBNActivationNHWC(...)
  ...
)
```

This is useful for:
- **Debugging** - Isolate problematic modules
- **Performance tuning** - Compare PyTorch vs TTNN for specific layers
- **Mixed execution** - Run certain layers on CPU/PyTorch while others use TTNN

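Conceptually, exclusion-aware replacement is a recursive walk over the module tree that swaps a leaf for its TTNN equivalent only when its dotted name is not excluded. A hypothetical stdlib-only sketch of that walk, using nested dicts of type names in place of real `nn.Module` objects (`walk_and_replace` is illustrative, not the actual helper in `utils/module_replacement.py`):

```python
def walk_and_replace(children, mapping, exclude, prefix=""):
    """Recursively swap leaf 'modules' (here, plain type-name strings)
    per mapping, skipping any dotted name listed in the exclude set."""
    replaced = {}
    for name, child in children.items():
        full_name = f"{prefix}.{name}" if prefix else name
        if isinstance(child, dict):
            # Container module: recurse with the extended dotted prefix.
            replaced[name] = walk_and_replace(child, mapping, exclude, full_name)
        elif child in mapping and full_name not in exclude:
            replaced[name] = mapping[child]   # swap for the TTNN equivalent
        else:
            replaced[name] = child            # excluded or unmapped: keep
    return replaced

# Walking a ResNet-style tree: layer1.0 is excluded, layer1.1 is not.
model_tree = {"layer1": {"0": "Bottleneck", "1": "Bottleneck"}}
out = walk_and_replace(model_tree, {"Bottleneck": "TTNNBottleneck"}, {"layer1.0"})
```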
## Creating a New TTNN Module

All TTNN modules inherit from `TTNNModule` and implement:

```python
from models.experimental.tt_symbiote.core.module import TTNNModule
import ttnn

class TTNNCustomLayer(TTNNModule):
    def __init__(self, param1, param2):
        super().__init__()
        self.param1 = param1
        self.param2 = param2

    @classmethod
    def from_torch(cls, torch_layer):
        """Create TTNN module from PyTorch equivalent."""
        new_layer = TTNNCustomLayer(torch_layer.param1, torch_layer.param2)
        new_layer._fallback_torch_layer = torch_layer
        return new_layer

    def preprocess_weights_impl(self):
        """Convert PyTorch weights to TTNN format (called once)."""
        self.tt_weight_host = ttnn.from_torch(
            self.torch_layer.weight,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT
        )

    def move_weights_to_device_impl(self):
        """Move preprocessed weights to device."""
        self.tt_weight = ttnn.to_device(self.tt_weight_host, self.device)

    def deallocate_weights_impl(self):
        """Deallocate device memory."""
        ttnn.deallocate(self.tt_weight)

    def forward(self, input_tensor):
        """TTNN forward implementation."""
        output = ttnn.custom_op(input_tensor, self.tt_weight)
        return output
```

**Key Methods:**
- `from_torch()`: Factory method to create from PyTorch module
- `preprocess_weights_impl()`: Convert weights to TTNN format (runs once)
- `move_weights_to_device_impl()`: Transfer weights to device
- `forward()`: TTNN implementation of the operation
- `deallocate_weights_impl()`: Free device memory

The base class handles:
- Automatic fallback to PyTorch on errors
- Tensor wrapping/unwrapping
- Weight lifecycle management
- Device placement

## Weight Management

The framework provides sophisticated weight lifecycle management:

```python
# Automatic weight preprocessing and device placement
module.preprocess_weights()       # Convert PyTorch → TTNN format (once)
module.move_weights_to_device()   # Transfer to device
module.deallocate_weights()       # Free device memory

# Auto-deallocation decorator for memory-constrained scenarios
from models.experimental.tt_symbiote.core.module import deallocate_weights_after

class TTNNLinearLLama(TTNNLinear):
    @deallocate_weights_after
    def forward(self, input_tensor):
        # Weights automatically deallocated after forward pass
        return super().forward(input_tensor)
```

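A decorator like `@deallocate_weights_after` can be pictured as a thin wrapper that calls the module's deallocation hook once the wrapped forward returns. A hypothetical stdlib-only sketch, with a dummy module standing in for a real TTNN module (the actual decorator lives in `core/module.py` and may differ in detail):

```python
import functools

def deallocate_weights_after(forward_fn):
    """Run forward, then free device weights, even if forward raises."""
    @functools.wraps(forward_fn)
    def wrapper(self, *args, **kwargs):
        try:
            return forward_fn(self, *args, **kwargs)
        finally:
            self.deallocate_weights()
    return wrapper

class DummyModule:
    """Stand-in module: tracks whether 'weights' are still on device."""
    def __init__(self):
        self.weights_on_device = True

    def deallocate_weights(self):
        self.weights_on_device = False

    @deallocate_weights_after
    def forward(self, x):
        return x * 2

m = DummyModule()
y = m.forward(3)
# After forward, the decorator has dropped the module's device weights.
```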
## TorchTTNNTensor

The framework uses a custom tensor wrapper that enables transparent operation dispatch:

```python
from models.experimental.tt_symbiote.core.tensor import TorchTTNNTensor

# Wrap PyTorch tensor for TTNN dispatch
tensor = TorchTTNNTensor(torch.randn(10, 20))

# Access underlying representations
tensor.to_torch  # Get PyTorch tensor
tensor.to_ttnn   # Get TTNN tensor

# Supports standard operations with automatic dispatch
result = tensor * 2.0 + 3.0  # Dispatched to TTNN backend
```

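The idea behind `__torch_dispatch__` is that every operator applied to the wrapper funnels through a single interception point, which can then route the computation to a backend. A stdlib-only analogy of that funneling (purely illustrative; the real `TorchTTNNTensor` subclasses `torch.Tensor` and overrides `__torch_dispatch__` rather than individual operators):

```python
class DispatchingScalar:
    """Toy wrapper: every arithmetic op funnels through one dispatch hook."""
    def __init__(self, value, log=None):
        self.value = value
        self.log = log if log is not None else []

    def _dispatch(self, op_name, other):
        # Single interception point: here we log and compute on the
        # "backend" (plain Python); TorchTTNNTensor routes to TTNN instead.
        self.log.append(op_name)
        ops = {"mul": lambda a, b: a * b, "add": lambda a, b: a + b}
        return DispatchingScalar(ops[op_name](self.value, other), self.log)

    def __mul__(self, other):
        return self._dispatch("mul", other)

    def __add__(self, other):
        return self._dispatch("add", other)

t = DispatchingScalar(5.0)
result = t * 2.0 + 3.0  # both ops pass through _dispatch, mirroring
                        # how `tensor * 2.0 + 3.0` is dispatched above
```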
## Running Tests

Tests work with pytest fixtures for device management:

```bash
pytest tests/test_vit.py            # ViT with TTNN Linear and LayerNorm
pytest tests/test_llama.py          # LLaMA-3.2-1B-Instruct
pytest tests/test_owl_vit.py        # OWL-ViT object detection
pytest tests/test_speech_t5.py      # SpeechT5 speech synthesis
pytest tests/test_whisper3.py       # Whisper-large-v3 automatic speech recognition
pytest tests/test_resnet50.py       # ResNet50 with Conv and Bottleneck
pytest tests/test_conv.py           # Standalone Conv2d, Conv+BN, Conv+BN+Activation
pytest tests/test_attention.py      # Self-attention module tests
pytest tests/test_yunet.py          # YUNet face detection
pytest tests/test_hunyuan_video.py  # HunyuanVideo 1.5 text-to-video generation
pytest tests/test_glm.py            # GLM-4.5-Air with mesh device support
pytest tests/test_gptoss.py         # GPT-OSS-20B model
pytest tests/test_openvla.py        # OpenVLA-7B vision-language-action model
pytest tests/test_dpl.py            # Debug Per Layer mode test
```

## Architecture

```
core/
├── module.py            # TTNNModule base class with auto-fallback
├── tensor.py            # TorchTTNNTensor wrapper for PyTorch dispatch
├── dispatcher.py        # TTNN operation dispatch handlers
├── torch_dispatcher.py  # PyTorch operation dispatch handlers
├── run_config.py        # Runtime configuration and mode management
├── utils.py             # Utility functions for dtype conversion
└── dispatchers/         # Dispatcher implementations
    ├── default_dispatcher.py  # Default operation dispatcher
    ├── debug_dispatcher.py    # Debug operation dispatcher
    ├── cpu_dispatcher.py      # CPU-only operation dispatcher
    └── dispatcher_config.py   # Dispatcher configuration

modules/                 # TTNN implementations
├── linear.py            # TTNNLinear, TTNNLinearLLama, TTNNLinearLLamaBFloat16
├── attention.py         # TTNNViTSelfAttention
├── normalization.py     # TTNNLayerNorm
├── activation.py        # TTNNSilu, TTNNReLU
├── conv.py              # TTNNConv2dNHWC, TTNNConv2dBNNHWC, TTNNBottleneck
└── tensor.py            # TTNNPermute, TTNNReshape

utils/
├── module_replacement.py  # Recursive module swapping
└── device_management.py   # Device configuration
```

## Available TTNN Modules

**Linear Layers:**
- `TTNNLinear` - Standard linear layer (bfloat16)
- `TTNNLinearLLama` - Optimized for LLaMA (bfloat8_b, auto-deallocates weights)
- `TTNNLinearLLamaBFloat16` - LLaMA variant with bfloat16
- `TTNNLinearGelu` - Linear layer with fused GELU activation

**Activation Functions:**
- `TTNNSilu` - SiLU/Swish activation
- `TTNNReLU` - ReLU activation
- `TTNNGelu` - GELU activation

**Normalization:**
- `TTNNLayerNorm` - Layer normalization

**Attention:**
- `TTNNViTSelfAttention` - Vision Transformer self-attention (deprecated, use `TTNNSelfAttention`)
- `TTNNSDPAAttention` - Scaled Dot-Product Attention
- `TTNNFusedQKVSelfAttention` - Self-attention with fused QKV projections
- `TTNNSelfAttention` - General self-attention module
- `TTNNWhisperAttention` - Whisper-specific attention implementation

**Convolution:**
- `TTNNConv2dNHWC` - 2D convolution with NHWC layout
- `TTNNConv2dBNNHWC` - Conv2d fused with BatchNorm
- `TTNNConv2dBNActivationNHWC` - Conv2d + BatchNorm + Activation
- `TTNNBottleneck` - ResNet bottleneck block
- `TTNNMaxPool2dNHWC` - 2D max pooling with NHWC layout
- `TTNNUpsampleNHWC` - Upsampling with NHWC layout
- `TTNNPatchEmbedding` - Vision Transformer patch embedding
- `TTNNViTEmbeddings` - Complete ViT embedding layer (patch + position)

**Tensor Operations:**
- `TTNNPermute` - Tensor permutation
- `TTNNReshape` - Tensor reshaping
- `TTNNAdd` - Element-wise addition

## Examples

See the [tests/](tests/) directory:
- [test_vit.py](tests/test_vit.py) - Vision Transformer with TTNN Linear and LayerNorm
- [test_llama.py](tests/test_llama.py) - LLaMA-3.2-1B-Instruct with bfloat8 optimizations
- [test_owl_vit.py](tests/test_owl_vit.py) - OWL-ViT object detection model
- [test_speech_t5.py](tests/test_speech_t5.py) - SpeechT5 speech synthesis model
- [test_whisper3.py](tests/test_whisper3.py) - Whisper-large-v3 automatic speech recognition
- [test_resnet50.py](tests/test_resnet50.py) - ResNet50 with Conv and Bottleneck blocks
- [test_conv.py](tests/test_conv.py) - Standalone Conv2d, Conv+BN, and Conv+BN+Activation tests
- [test_attention.py](tests/test_attention.py) - Self-attention module tests
- [test_yunet.py](tests/test_yunet.py) - YUNet face detection model
- [test_hunyuan_video.py](tests/test_hunyuan_video.py) - HunyuanVideo 1.5 text-to-video generation
- [test_glm.py](tests/test_glm.py) - GLM-4.5-Air with mesh device support
- [test_gptoss.py](tests/test_gptoss.py) - GPT-OSS-20B model
- [test_openvla.py](tests/test_openvla.py) - OpenVLA-7B vision-language-action model
- [test_dpl.py](tests/test_dpl.py) - Debug Per Layer mode test
