# TT-Symbiote

A PyTorch-to-TTNN framework for transparent hardware acceleration of neural networks on Tenstorrent devices.

## Overview

TT-Symbiote enables TTNN acceleration of pretrained PyTorch models by replacing standard PyTorch modules (e.g., `nn.Linear`, `nn.LayerNorm`) with TTNN-optimized equivalents. The framework automatically handles:
- Module replacement and weight conversion
- Device management and memory allocation
- Fallback to PyTorch when TTNN operations fail

## Flow Diagram



## Run Modes

TT-Symbiote supports multiple execution modes via the `TT_SYMBIOTE_RUN_MODE` environment variable:

- **NORMAL** - Default TTNN execution mode
- **NORMAL_WITH_FALLBACK** - TTNN execution with automatic PyTorch fallback on errors
- **SEL** - Segment Each Layer mode: PyTorch takes TTNN tensors as input and compares outputs with PCC
- **DPL** - Debug Per Layer mode: runs TTNN and PyTorch separately and compares outputs with PCC
- **DPL_NO_ERROR_PROP** - Same as DPL, but TTNN takes PyTorch tensors as input to avoid error propagation
- **CPU** - CPU-only execution mode

```bash
# Set execution mode before running
export TT_SYMBIOTE_RUN_MODE=NORMAL && pytest tests/test_resnet50.py

# Or use DPL mode for debugging
export TT_SYMBIOTE_RUN_MODE=DPL && pytest tests/test_vit.py
```
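
The comparison modes (SEL, DPL, DPL_NO_ERROR_PROP) report agreement as a PCC (Pearson correlation coefficient) between TTNN and PyTorch outputs. For reference, a minimal sketch of such a check; the framework's own implementation may differ in how it flattens tensors or handles edge cases:

```python
import torch

def pcc(a: torch.Tensor, b: torch.Tensor) -> float:
    """Pearson correlation coefficient between two tensors, flattened."""
    a = a.flatten().to(torch.float32)
    b = b.flatten().to(torch.float32)
    a = a - a.mean()
    b = b - b.mean()
    return float((a @ b) / (a.norm() * b.norm() + 1e-12))

ttnn_out = torch.randn(2, 8)                      # stand-in for a TTNN layer output (as torch)
torch_out = ttnn_out + 0.01 * torch.randn(2, 8)   # stand-in for the PyTorch reference
print(f"PCC: {pcc(ttnn_out, torch_out):.5f}")     # values close to 1.0 indicate matching outputs
```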

## Dispatcher Configuration

TT-Symbiote supports multiple dispatcher implementations via the `TT_SYMBIOTE_DISPATCHER` environment variable:

- **DEFAULT** - Standard TTNN operation dispatcher (not the default; when this variable is unset, the CPU dispatcher is used)
- **DEBUG** - Verbose logging dispatcher for debugging
- **CPU** - CPU-only dispatcher for testing

```bash
# Use default dispatcher
export TT_SYMBIOTE_DISPATCHER=DEFAULT

# Use debug dispatcher for verbose operation logging
export TT_SYMBIOTE_DISPATCHER=DEBUG

# Use CPU dispatcher
export TT_SYMBIOTE_DISPATCHER=CPU

# Combine with run mode for advanced debugging
export TT_SYMBIOTE_DISPATCHER=CPU && export TT_SYMBIOTE_RUN_MODE=DPL_NO_ERROR_PROP && pytest tests/test_speech_t5.py
```
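
Both variables can also be set from Python, for example at the top of a test; a small sketch, assuming the configuration is read after these lines execute (i.e., before the framework modules are used):

```python
import os

# Programmatic equivalent of the exports above
os.environ["TT_SYMBIOTE_DISPATCHER"] = "DEBUG"
os.environ["TT_SYMBIOTE_RUN_MODE"] = "NORMAL_WITH_FALLBACK"
```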

## Quick Start

```python
import torch
import ttnn
from torch import nn
from transformers import AutoModelForImageClassification
from models.experimental.tt_symbiote.utils.module_replacement import register_module_replacement_dict
from models.experimental.tt_symbiote.utils.device_management import set_device
from models.experimental.tt_symbiote.modules.linear import TTNNLinear
from models.experimental.tt_symbiote.modules.normalization import TTNNLayerNorm

# Load model
model = AutoModelForImageClassification.from_pretrained("google/vit-base-patch16-224")

# Define module replacement mapping
nn_to_ttnn = {
    nn.Linear: TTNNLinear,
    nn.LayerNorm: TTNNLayerNorm,
}

# Replace modules and set device
register_module_replacement_dict(model, nn_to_ttnn, model_config=None)

# Get a TTNN device (e.g., through a pytest fixture or ttnn.CreateDevice);
# ttnn.open_device is used here as one option
ttnn_device = ttnn.open_device(device_id=0)

set_device(model, ttnn_device)

# Run inference
model.eval()
torch.set_grad_enabled(False)
result = model(torch.randn(1, 3, 224, 224))
```

## Selective Module Replacement

You can selectively exclude specific modules from replacement using the `exclude_replacement` parameter:

```python
# Replace all Bottleneck modules except layer1.0
register_module_replacement_dict(
    model,
    nn_to_ttnn,
    model_config={"program_config_ffn": {}},
    exclude_replacement=set(["layer1.0"])
)
```

**How it works:**

1. **Initial replacement** - First, run without exclusions to see the module names:
```python
register_module_replacement_dict(model, nn_to_ttnn, exclude_replacement=set([]))
```

2. **Identify module names** - Check the model structure. TTNN modules show their `module_name`:
```python
# layer1.0 is now a TTNNBottleneck with module_name=layer1.0
(0): TTNNBottleneck(module_name=layer1.0
  (conv1): TTNNConv2dBNActivationNHWC(...)
  (conv2): TTNNConv2dBNActivationNHWC(...)
  ...
)
(1): TTNNBottleneck(module_name=layer1.1
  (conv1): TTNNConv2dBNActivationNHWC(...)
  ...
)
```

3. **Re-run with exclusions** - Use the module names to exclude specific modules:
```python
# Exclude layer1.0 from replacement - it stays as PyTorch Bottleneck
register_module_replacement_dict(
    model,
    nn_to_ttnn,
    exclude_replacement=set(["layer1.0"])
)
```

**Result:**
```python
# layer1.0 remains as original PyTorch Bottleneck
(0): Bottleneck(
  (conv1): Conv2d(64, 64, kernel_size=(1, 1), ...)
  (bn1): BatchNorm2d(64, ...)
  ...
)
# layer1.1 is replaced with TTNN
(1): TTNNBottleneck(module_name=layer1.1
  (conv1): TTNNConv2dBNActivationNHWC(...)
  ...
)
```

This is useful for:
- **Debugging** - Isolate problematic modules
- **Performance tuning** - Compare PyTorch vs TTNN for specific layers
- **Mixed execution** - Run certain layers on CPU/PyTorch while others use TTNN (see the sketch below)
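
For example, a sketch of mixed execution that keeps the whole first ResNet stage on PyTorch while the rest runs through TTNN; the module names are illustrative and should be taken from the printed model structure:

```python
# Keep all of layer1 on PyTorch; everything else is replaced with TTNN modules
register_module_replacement_dict(
    model,
    nn_to_ttnn,
    exclude_replacement={"layer1.0", "layer1.1", "layer1.2"},
)
```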

## Creating a New TTNN Module

All TTNN modules inherit from `TTNNModule` and implement the following interface:

```python
from models.experimental.tt_symbiote.core.module import TTNNModule
import ttnn
from torch import nn

class TTNNCustomLayer(TTNNModule):
    def __init__(self, param1, param2):
        super().__init__()
        self.param1 = param1
        self.param2 = param2

    @classmethod
    def from_torch(cls, torch_layer):
        """Create TTNN module from PyTorch equivalent."""
        new_layer = TTNNCustomLayer(torch_layer.param1, torch_layer.param2)
        new_layer._fallback_torch_layer = torch_layer
        return new_layer

    def preprocess_weights_impl(self):
        """Convert PyTorch weights to TTNN format (called once)."""
        self.tt_weight_host = ttnn.from_torch(
            self.torch_layer.weight,
            dtype=ttnn.bfloat16,
            layout=ttnn.TILE_LAYOUT
        )

    def move_weights_to_device_impl(self):
        """Move preprocessed weights to device."""
        self.tt_weight = ttnn.to_device(self.tt_weight_host, self.device)

    def deallocate_weights_impl(self):
        """Deallocate device memory."""
        ttnn.deallocate(self.tt_weight)

    def forward(self, input_tensor):
        """TTNN forward implementation."""
        output = ttnn.custom_op(input_tensor, self.tt_weight)
        return output
```

**Key Methods:**
- `from_torch()`: Factory method to create from PyTorch module
- `preprocess_weights_impl()`: Convert weights to TTNN format (runs once)
- `move_weights_to_device_impl()`: Transfer weights to device
- `forward()`: TTNN implementation of the operation
- `deallocate_weights_impl()`: Free device memory

The base class handles:
- Automatic fallback to PyTorch on errors
- Tensor wrapping/unwrapping
- Weight lifecycle management
- Device placement
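
Once defined, a custom module plugs into the same replacement flow as the built-in ones. A sketch, where `CustomLayer` is a hypothetical placeholder for the PyTorch module being replaced, and `model`/`ttnn_device` come from a setup like the Quick Start above:

```python
from models.experimental.tt_symbiote.utils.module_replacement import register_module_replacement_dict
from models.experimental.tt_symbiote.utils.device_management import set_device

# CustomLayer is a hypothetical PyTorch module; TTNNCustomLayer is defined above
nn_to_ttnn = {CustomLayer: TTNNCustomLayer}

register_module_replacement_dict(model, nn_to_ttnn, model_config=None)
set_device(model, ttnn_device)  # weight preprocessing and device placement are handled by the framework
```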

## Weight Management

The framework manages the full weight lifecycle (preprocessing, device placement, and deallocation):

```python
# Automatic weight preprocessing and device placement
module.preprocess_weights()      # Convert PyTorch → TTNN format (once)
module.move_weights_to_device()  # Transfer to device
module.deallocate_weights()      # Free device memory

# Auto-deallocation decorator for memory-constrained scenarios
from models.experimental.tt_symbiote.core.module import deallocate_weights_after

class TTNNLinearLLama(TTNNLinear):
    @deallocate_weights_after
    def forward(self, input_tensor):
        # Weights automatically deallocated after forward pass
        return super().forward(input_tensor)
```
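
A sketch of driving the lifecycle manually across a whole model, assuming the methods above are exposed on every replaced `TTNNModule` instance and that `model` is the replaced model from the Quick Start:

```python
import torch
from models.experimental.tt_symbiote.core.module import TTNNModule

ttnn_modules = [m for m in model.modules() if isinstance(m, TTNNModule)]

for m in ttnn_modules:
    m.preprocess_weights()      # one-time PyTorch -> TTNN conversion
    m.move_weights_to_device()  # allocate weights on the device

result = model(torch.randn(1, 3, 224, 224))

for m in ttnn_modules:
    m.deallocate_weights()      # free device memory after inference
```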

## TorchTTNNTensor

The framework uses a custom tensor wrapper that enables transparent operation dispatch:

```python
import torch

from models.experimental.tt_symbiote.core.tensor import TorchTTNNTensor

# Wrap PyTorch tensor for TTNN dispatch
tensor = TorchTTNNTensor(torch.randn(10, 20))

# Access underlying representations
tensor.to_torch  # Get PyTorch tensor
tensor.to_ttnn   # Get TTNN tensor

# Supports standard operations with automatic dispatch
result = tensor * 2.0 + 3.0  # Dispatched to TTNN backend
```
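
A short round-trip sketch comparing the dispatched result against plain PyTorch; it assumes the result of a dispatched operation is itself a `TorchTTNNTensor` exposing `to_torch`, and the tolerances are illustrative:

```python
x = torch.randn(10, 20)
wrapped = TorchTTNNTensor(x)

out = wrapped * 2.0 + 3.0  # dispatched through the TTNN backend
reference = x * 2.0 + 3.0  # plain PyTorch reference

torch.testing.assert_close(out.to_torch, reference, rtol=1e-2, atol=1e-2)
```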

## Running Tests

Tests rely on pytest fixtures for device management:

```bash
pytest tests/test_vit.py # ViT with TTNN Linear and LayerNorm
pytest tests/test_llama.py # LLaMA-3.2-1B-Instruct
pytest tests/test_owl_vit.py # OWL-ViT object detection
pytest tests/test_speech_t5.py # SpeechT5 speech synthesis
pytest tests/test_whisper3.py # Whisper-large-v3 automatic speech recognition
pytest tests/test_resnet50.py # ResNet50 with Conv and Bottleneck
pytest tests/test_conv.py # Standalone Conv2d, Conv+BN, Conv+BN+Activation
pytest tests/test_attention.py # Self-attention module tests
pytest tests/test_yunet.py # YUNet face detection
pytest tests/test_hunyuan_video.py # HunyuanVideo 1.5 text-to-video generation
pytest tests/test_glm.py # GLM-4.5-Air with mesh device support
pytest tests/test_gptoss.py # GPT-OSS-20B model
pytest tests/test_openvla.py # OpenVLA-7B vision-language-action model
pytest tests/test_dpl.py # Debug Per Layer mode test
```
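
A minimal sketch of what such a test can look like, assuming the repository's conftest provides a `device` fixture that yields an open TTNN device (the fixture name and the toy model are assumptions):

```python
import torch
from torch import nn

from models.experimental.tt_symbiote.utils.module_replacement import register_module_replacement_dict
from models.experimental.tt_symbiote.utils.device_management import set_device
from models.experimental.tt_symbiote.modules.linear import TTNNLinear


def test_linear_replacement(device):  # `device` fixture assumed to come from conftest
    model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 8)).eval()
    register_module_replacement_dict(model, {nn.Linear: TTNNLinear}, model_config=None)
    set_device(model, device)

    with torch.no_grad():
        out = model(torch.randn(4, 32))

    assert out.shape == (4, 8)
```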

## Architecture

```
core/
├── module.py # TTNNModule base class with auto-fallback
├── tensor.py # TorchTTNNTensor wrapper for PyTorch dispatch
├── dispatcher.py # TTNN operation dispatch handlers
├── torch_dispatcher.py # PyTorch operation dispatch handlers
├── run_config.py # Runtime configuration and mode management
├── utils.py # Utility functions for dtype conversion
└── dispatchers/ # Dispatcher implementations
    ├── default_dispatcher.py # Default operation dispatcher
    ├── debug_dispatcher.py # Debug operation dispatcher
    ├── cpu_dispatcher.py # CPU-only operation dispatcher
    └── dispatcher_config.py # Dispatcher configuration

modules/ # TTNN implementations
├── linear.py # TTNNLinear, TTNNLinearLLama, TTNNLinearLLamaBFloat16
├── attention.py # TTNNViTSelfAttention
├── normalization.py # TTNNLayerNorm
├── activation.py # TTNNSilu, TTNNReLU
├── conv.py # TTNNConv2dNHWC, TTNNConv2dBNNHWC, TTNNBottleneck
└── tensor.py # TTNNPermute, TTNNReshape

utils/
├── module_replacement.py # Recursive module swapping
└── device_management.py # Device configuration
```

## Available TTNN Modules

**Linear Layers:**
- `TTNNLinear` - Standard linear layer (bfloat16)
- `TTNNLinearLLama` - Optimized for LLaMA (bfloat8_b, auto-deallocates weights)
- `TTNNLinearLLamaBFloat16` - LLaMA variant with bfloat16
- `TTNNLinearGelu` - Linear layer with fused GELU activation

**Activation Functions:**
- `TTNNSilu` - SiLU/Swish activation
- `TTNNReLU` - ReLU activation
- `TTNNGelu` - GELU activation

**Normalization:**
- `TTNNLayerNorm` - Layer normalization

**Attention:**
- `TTNNViTSelfAttention` - Vision Transformer self-attention (deprecated, use `TTNNSelfAttention`)
- `TTNNSDPAAttention` - Scaled Dot-Product Attention
- `TTNNFusedQKVSelfAttention` - Self-attention with fused QKV projections
- `TTNNSelfAttention` - General self-attention module
- `TTNNWhisperAttention` - Whisper-specific attention implementation

**Convolution:**
- `TTNNConv2dNHWC` - 2D convolution with NHWC layout
- `TTNNConv2dBNNHWC` - Conv2d fused with BatchNorm
- `TTNNConv2dBNActivationNHWC` - Conv2d + BatchNorm + Activation
- `TTNNBottleneck` - ResNet bottleneck block
- `TTNNMaxPool2dNHWC` - 2D max pooling with NHWC layout
- `TTNNUpsampleNHWC` - Upsampling with NHWC layout
- `TTNNPatchEmbedding` - Vision Transformer patch embedding
- `TTNNViTEmbeddings` - Complete ViT embedding layer (patch + position)

**Tensor Operations:**
- `TTNNPermute` - Tensor permutation
- `TTNNReshape` - Tensor reshaping
- `TTNNAdd` - Element-wise addition
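
A broader replacement map combining several of these modules; the import paths follow the `modules/` layout shown above, and the exact PyTorch-to-TTNN pairing for a given model is an assumption to adjust as needed:

```python
from torch import nn
from models.experimental.tt_symbiote.modules.linear import TTNNLinear
from models.experimental.tt_symbiote.modules.normalization import TTNNLayerNorm
from models.experimental.tt_symbiote.modules.activation import TTNNSilu, TTNNReLU

nn_to_ttnn = {
    nn.Linear: TTNNLinear,        # standard bfloat16 linear
    nn.LayerNorm: TTNNLayerNorm,  # layer normalization
    nn.SiLU: TTNNSilu,            # SiLU/Swish activation
    nn.ReLU: TTNNReLU,            # ReLU activation
}
```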

## Examples

See the [tests/](tests/) directory:
- [test_vit.py](tests/test_vit.py) - Vision Transformer with TTNN Linear and LayerNorm
- [test_llama.py](tests/test_llama.py) - LLaMA-3.2-1B-Instruct with bfloat8 optimizations
- [test_owl_vit.py](tests/test_owl_vit.py) - OWL-ViT object detection model
- [test_speech_t5.py](tests/test_speech_t5.py) - SpeechT5 speech synthesis model
- [test_whisper3.py](tests/test_whisper3.py) - Whisper-large-v3 automatic speech recognition
- [test_resnet50.py](tests/test_resnet50.py) - ResNet50 with Conv and Bottleneck blocks
- [test_conv.py](tests/test_conv.py) - Standalone Conv2d, Conv+BN, and Conv+BN+Activation tests
- [test_attention.py](tests/test_attention.py) - Self-attention module tests
- [test_yunet.py](tests/test_yunet.py) - YUNet face detection model
- [test_hunyuan_video.py](tests/test_hunyuan_video.py) - HunyuanVideo 1.5 text-to-video generation
- [test_glm.py](tests/test_glm.py) - GLM-4.5-Air with mesh device support
- [test_gptoss.py](tests/test_gptoss.py) - GPT-OSS-20B model
- [test_openvla.py](tests/test_openvla.py) - OpenVLA-7B vision-language-action model
- [test_dpl.py](tests/test_dpl.py) - Debug Per Layer mode test