You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
[TTNN.JIT] Add Support for Width and Height Sharded Tensors (#5969)
### Ticket
#5836#5837
### Problem description
TTNN JIT needs to support width and height sharded tensors.
#5437 added the required
support for representing these layouts and generating DMA in D2M. This
PR is focused on the JIT frontend and end to end execution.
### Changes in JIT Frontend
- Emit height sharded and width sharded `TTNNLayoutAttr`s
- Set `exact_grid` on all generated `TTNNLayoutAttr` to avoid grid
collapsing when generating a TTNN flat buffer
- Check for unsupported tensor layouts and raise exceptions
- Centralize duplicate tensor layout translation code between the AST
parser and graph tracer into `tensor_translator.py`
### Changes in Core D2M
- Modified `TTIRToD2M` pass:
- Match height and width sharded tensors and create virtual grids to
represent the shard distribution
- Modified `D2MToTTNN` pass:
- Set the correct grid on `ttnn.generic` ops when the corresponding
`d2m.generic` has a virtual grid
- Include a fix for virtual grids from @bgrady-tt in
36148f7
### Changes to TTNN Dialect and Flatbuffer Generation
The problem:
TTNN JIT needs to represent already fully specified tensors using the
TTNN dialect, while the TTNN compiler gets to pick the properties of the
tensor. This is problematic for height and width sharding since TTNN JIT
needs to represent exact physical grids (e.g. `6x2` and **NOT** `3x4`),
while the TTNN compiler can use a virtual grid (`12x1` or `1x12` in this
case) and **choose** a desired physical collapsing
The solution:
Added an `exact_grid` optional parameter on `TTNNLayoutAttr`. If this
parameter is set to true, the grid is not collapsed and recorded in the
flatbuffer as is.
### Tests
- IR tests for virtual grid and metal layout generation
- Sweeping of all grids in `test_layouts.py`
- Height and width sharded layouts added to all existing tests in
`test_eltwise.py` and `test_eltwise_composite.py`
- Closes#5836
- Closes#5837
### Checklist
- [X] New/Existing tests provide coverage for changes
---------
Co-authored-by: Brett Grady <[email protected]>
AttrParameter<"MemRefType", "A memref that describes the physical footprint allocation of the shard. It must also have a shape with rank equal to grid.">:$memref,
OptionalParameter<"bool", "A status flag, asking the users to ignore the physical layout. This is used to model a sharded layout with unspecified shard shape.">:$ignorePhysicalLayout);
OptionalParameter<"bool", "A status flag, asking the users to ignore the physical layout. This is used to model a sharded layout with unspecified shard shape.">:$ignorePhysicalLayout,
648
+
OptionalParameter<"bool", "A status flag indicating that the grid should be treated as the exact physical grid and not a virtual grid to becollapsed">:$exactGrid);
0 commit comments