Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/copy_lock_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def compute_func():
out_block.store(result)

# finalize push, this advances the cb pointers, the writing happened at the line above
out_cb.push()
a_in_cb.pop()
b_in_cb.pop()
out_block.push()
a_block.pop()
b_block.pop()

@ttl.datamovement()
def dm0():
Expand All @@ -87,11 +87,11 @@ def dm0():
# INTENTIONAL ERROR: Attempting to write to a_block before tx.wait()
a_block.store([None, None]) # This should trigger a copy lock error
tx.wait()
a_in_cb.push()
a_block.push()
b_block = b_in_cb.reserve()
tx = ttl.copy(b_in[row_slice, col_slice], b_block)
tx.wait()
b_in_cb.push()
b_block.push()

@ttl.datamovement()
def dm1():
Expand All @@ -111,7 +111,7 @@ def dm1():
tx = ttl.copy(out_block, out[row_slice, col_slice])

tx.wait()
out_cb.pop()
out_block.pop()


def main() -> None:
Expand Down
14 changes: 7 additions & 7 deletions examples/eltwise_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def compute_func():
out_block.store(result)

# finalize push, this advances the cb pointers, the writing happened at the line above
out_cb.push()
out_block.push()
# finalize pop, this advances the cb pointers, essentially freeing the memory
# After poping, the corresponding Block(a_block) points to stale data. Should probably make it an error to access it at that point
a_in_cb.pop()
a_block.pop()
# ditto
b_in_cb.pop()
c_in_cb.pop()
b_block.pop()
c_block.pop()

@ttl.datamovement()
def dm0():
Expand Down Expand Up @@ -136,11 +136,11 @@ def pipe_dst(pipe_id):
a_block = a_in_cb.reserve()
tx = ttl.copy(a_in[row_slice, col_slice], a_block)
tx.wait()
a_in_cb.push()
a_block.push()
b_block = b_in_cb.reserve()
tx = ttl.copy(b_in[row_slice, col_slice], b_block)
tx.wait()
b_in_cb.push()
b_block.push()

@ttl.datamovement()
def dm1():
Expand All @@ -160,7 +160,7 @@ def dm1():

tx = ttl.copy(out_block, out[row_slice, col_slice])
tx.wait()
out_cb.pop()
out_block.pop()


def main() -> None:
Expand Down
14 changes: 7 additions & 7 deletions examples/eltwise_pipe_core3.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,13 @@ def compute_func():
out_block.store(result)

# finalize push, this advances the cb pointers, the writing happened at the line above
out_cb.push()
out_block.push()
# finalize pop, this advances the cb pointers, essentially freeing the memory
# After poping, the corresponding Block(a_block) points to stale data. Should probably make it an error to access it at that point
a_in_cb.pop()
a_block.pop()
# ditto
b_in_cb.pop()
c_in_cb.pop()
b_block.pop()
c_block.pop()

@ttl.datamovement()
def dm0():
Expand Down Expand Up @@ -136,11 +136,11 @@ def pipe_dst(pipe_id):
a_block = a_in_cb.reserve()
tx = ttl.copy(a_in[row_slice, col_slice], a_block)
tx.wait()
a_in_cb.push()
a_block.push()
b_block = b_in_cb.reserve()
tx = ttl.copy(b_in[row_slice, col_slice], b_block)
tx.wait()
b_in_cb.push()
b_block.push()

@ttl.datamovement()
def dm1():
Expand All @@ -160,7 +160,7 @@ def dm1():

tx = ttl.copy(out_block, out[row_slice, col_slice])
tx.wait()
out_cb.pop()
out_block.pop()


def main() -> None:
Expand Down
4 changes: 2 additions & 2 deletions python/sim/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ class ExpectedOp(Enum):
COPY_SRC = auto() # Expect copy(blk, ...) - block as source
COPY_DST = auto() # Expect copy(..., blk) - block as destination
TX_WAIT = auto() # Expect tx.wait()
PUSH = auto() # Expect cb.push()
POP = auto() # Expect cb.pop()
PUSH = auto() # Expect blk.push()
POP = auto() # Expect blk.pop()
STORE = (
auto()
) # Expect blk.store(...) - block as destination, regular store (acc=False)
Expand Down
26 changes: 16 additions & 10 deletions python/sim/cb.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def __truediv__(self, other: "Block") -> "Block":
def __matmul__(self, other: "Block") -> "Block":
return self._block.__matmul__(other)

def pop(self) -> None:
self._cb.pop_block()

def push(self) -> None:
self._cb.push_block()


class ReserveContext(_BlockContextManager):
"""Context manager for reserve operations that automatically pushes on exit.
Expand All @@ -130,11 +136,11 @@ class ReserveContext(_BlockContextManager):
Or without (for backward compatibility):
blk = cb.reserve()
blk.store(data)
cb.push() # manual push required
blk.push() # manual push required
"""

def __init__(self, cb: "CircularBuffer", block: Block):
super().__init__(cb, block, cb.push)
super().__init__(cb, block, cb.push_block)


class WaitContext(_BlockContextManager):
Expand All @@ -147,11 +153,11 @@ class WaitContext(_BlockContextManager):
Or without (for backward compatibility):
blk = cb.wait()
data = blk[0]
cb.pop() # manual pop required
blk.pop() # manual pop required
"""

def __init__(self, cb: "CircularBuffer", block: Block):
super().__init__(cb, block, cb.pop)
super().__init__(cb, block, cb.pop_block)


# TODO: Should this class now be private?
Expand All @@ -172,12 +178,12 @@ class CircularBuffer:
# Producer workflow
write_view = cb.reserve() # Reserve space for 6 tiles
# ... write data to write_view ...
cb.push() # Make data visible
write_view.push() # Make data visible

# Consumer workflow
read_view = cb.wait() # Wait for 6 tiles
# ... read data from read_view ...
cb.pop() # Free consumed tiles
read_view.pop() # Free consumed tiles
"""

def __init__(
Expand Down Expand Up @@ -260,7 +266,7 @@ def wait(self) -> WaitContext:
Or without (for backward compatibility):
blk = cb.wait()
data = blk[0]
cb.pop() # manual pop required
blk.pop() # manual pop required

Returns:
Context manager providing read access to the available tiles
Expand Down Expand Up @@ -322,7 +328,7 @@ def reserve(self) -> ReserveContext:
Or without (for backward compatibility):
blk = cb.reserve()
blk.store(data)
cb.push() # manual push required
blk.push() # manual push required

Returns:
Context manager providing write access to the reserved space
Expand Down Expand Up @@ -372,7 +378,7 @@ def can_reserve(self) -> bool:
stats = api.cb_stats(cb_id)
return stats.free >= self._tiles_per_operation

def push(self) -> None:
def push_block(self) -> None:
"""
Finalize a write operation, making reserved data visible to consumers.

Expand All @@ -393,7 +399,7 @@ def push(self) -> None:
api, cb_id = self._ensure_initialized()
api.cb_push_back(cb_id, self._tiles_per_operation)

def pop(self) -> None:
def pop_block(self) -> None:
"""
Finalize a read operation, freeing consumed data.

Expand Down
40 changes: 6 additions & 34 deletions python/ttl/circular_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,28 +89,14 @@ def wait(ast_self: "CircularBuffer") -> "TensorBlock":
TensorBlock: The acquired data with CB association.

Example:
shard = cb.wait()
result = compute(shard)
cb.pop()
block = cb.wait()
result = compute(block)
block.pop()
"""
tensor_type = _get_cb_tensor_type(ast_self)
tensor = ttl.cb_wait(tensor_type, ast_self)
return ttl.attach_cb(tensor.type, tensor, ast_self)

def pop(ast_self: "CircularBuffer") -> None:
"""
Signal that data has been consumed (consumer release).

Use in consumer threads after wait() to signal that data has been
consumed and space is available for producers.

Example:
shard = cb.wait()
result = compute(shard)
cb.pop() # Signal consumption complete
"""
ttl.cb_pop(ast_self)

def reserve(ast_self: "CircularBuffer") -> "TensorBlock":
"""
Reserve space in the circular buffer (producer acquire).
Expand All @@ -122,28 +108,14 @@ def reserve(ast_self: "CircularBuffer") -> "TensorBlock":
TensorBlock: The reserved space with CB association.

Example:
cb.reserve()
copy(stream[idx], cb).wait()
cb.push()
block = cb.reserve()
copy(stream[idx], block).wait()
block.push()
"""
tensor_type = _get_cb_tensor_type(ast_self)
tensor = ttl.cb_reserve(tensor_type, ast_self)
return ttl.attach_cb(tensor.type, tensor, ast_self)

def push(ast_self: "CircularBuffer") -> None:
"""
Signal that data is ready in the circular buffer (producer release).

Use in producer threads after reserve() to signal that data has been
written and is ready for consumers.

Example:
shard = cb.reserve()
copy(stream[idx], shard).wait()
cb.push() # Signal data ready
"""
ttl.cb_push(ast_self)


def make_circular_buffer_like(
tensor: Any,
Expand Down
48 changes: 47 additions & 1 deletion python/ttl/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,55 @@ def store(ast_self: TensorBlock, rhs: TensorBlock) -> None:
"""Store result tensor to CB by propagating CB association from output view."""
# ast_self is the result of attach_cb(tensor, cb) from reserve()
# Extract the CB operand and attach it to the result tensor
cb = ast_self.owner.operands[1]
if not _is_block(ast_self):
raise ValueError(
"store() must be called on a block acquired from reserve(), not a regular tensor"
)
cb = _get_cb_from_block(ast_self)
return ttl.attach_cb(rhs.type, rhs, cb)

def push(ast_self: TensorBlock) -> None:
"""
Signal that data is ready in the circular buffer (producer release).

Finalizes a reserve() operation by signaling that the block has been
written and is ready for consumers. This operation is non-blocking.

Must be called on a block acquired via reserve().

Example:
block = cb.reserve()
ttl.copy(data, block).wait()
block.push() # Signal data ready
"""
if not _is_block(ast_self):
raise ValueError(
"push() must be called on a block acquired from reserve(), not a regular tensor"
)
cb = _get_cb_from_block(ast_self)
ttl.cb_push(cb)

def pop(ast_self: TensorBlock) -> None:
"""
Signal that data has been consumed (consumer release).

Finalizes a wait() operation by signaling that the block has been
consumed and space is available for producers. This operation is non-blocking.

Must be called on a block acquired via wait().

Example:
block = cb.wait()
result = compute(block)
block.pop() # Signal consumption complete
"""
if not _is_block(ast_self):
raise ValueError(
"pop() must be called on a block acquired from wait(), not a regular tensor"
)
cb = _get_cb_from_block(ast_self)
ttl.cb_pop(cb)


@syntax("!ttl.transfer_handle")
class CopyTransferHandler:
Expand Down
8 changes: 4 additions & 4 deletions test/python/debug_locations.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,22 @@ def compute_thread():
l = lhs_cb.wait()
o = out_cb.reserve()
o.store(l)
lhs_cb.pop()
out_cb.push()
l.pop()
o.push()

@ttl.datamovement()
def dm_read():
lhs_blk = lhs_cb.reserve()
tx = ttl.copy(lhs[0, 0], lhs_blk)
tx.wait()
lhs_cb.push()
lhs_blk.push()

@ttl.datamovement()
def dm_write():
out_blk = out_cb.wait()
tx = ttl.copy(out_blk, out[0, 0])
tx.wait()
out_cb.pop()
out_blk.pop()


# Verify function definitions exist
Expand Down
Loading