[serve.llm] Prefix-aware scheduler [1/N] Adding Prefix-aware tree data structure #52747

Open
wants to merge 6 commits into
base: master
Changes from 5 commits
379 changes: 379 additions & 0 deletions python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py
@@ -0,0 +1,379 @@
from ray import serve
import time
from threading import RLock
from typing import Optional, List, Tuple, Dict, Set


class Node:
    """
    Node in a prefix tree that tracks tenant access time.
    Each node represents a segment of text and can belong to multiple tenants.
    """
Contributor

Can you show an example of a tree representation with this node structure?
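
A minimal sketch of the shape being asked about, assuming tenant "a" inserted "helloworld" and tenant "b" then inserted "hellothere". `SketchNode` and `render` are hypothetical stand-ins written for this illustration, not the PR's classes, and the timestamps 1 and 2 are made up.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class SketchNode:
    """Hypothetical stand-in for the PR's Node; not the actual class."""

    text: str = ""
    children: Dict[str, "SketchNode"] = field(default_factory=dict)  # first char -> child
    tenant_last_access_time: Dict[str, int] = field(default_factory=dict)  # tenant -> ms


def render(node: SketchNode, depth: int = 0) -> List[str]:
    """Return one indented line per node, listing the tenants that own it."""
    tenants = sorted(node.tenant_last_access_time)
    lines = [f"{'  ' * depth}{node.text!r} tenants={tenants}"]
    for first_char in sorted(node.children):
        lines.extend(render(node.children[first_char], depth + 1))
    return lines


# The shared prefix "hello" becomes a single node owned by both tenants,
# and each distinct suffix hangs below it, keyed by its first character.
world = SketchNode("world", tenant_last_access_time={"a": 1})
there = SketchNode("there", tenant_last_access_time={"b": 2})
hello = SketchNode("hello", {"w": world, "t": there}, {"a": 1, "b": 2})
root = SketchNode("", {"h": hello}, {"a": 1, "b": 2})

print("\n".join(render(root)))
```

The printed tree has the empty root at the top, "hello" one level down, and "there" and "world" as siblings under it.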


    def __init__(self, text: str = "", parent: Optional["Node"] = None):
Contributor

Can an empty string be a valid text? Maybe we should use None to represent the initial case?

        self.text: str = text
        self.parent: Optional["Node"] = parent
        self.children: Dict[str, "Node"] = {}  # Maps char -> Node
        self.tenant_last_access_time: Dict[
Contributor

What is a "tenant"? Is it necessary to store all of their access times here? Perhaps this should be a single latest access time for the current node?

            str, int
        ] = {}  # Maps tenant -> timestamp in ms (int)
Contributor

q: Shouldn't this be a heap for optimized access for LRU eviction policies?
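
One way the heap idea could look, as a sketch only: a min-heap of `(timestamp_ms, key)` pairs with lazy deletion, so that re-accessing a key pushes a new entry and the stale one is skipped at pop time. `LruHeap`, `touch`, and `pop_oldest` are hypothetical names, and keys are plain strings here rather than the PR's Node objects.

```python
import heapq
from typing import Dict, List, Tuple


class LruHeap:
    """Hypothetical helper: min-heap of (timestamp_ms, key) with lazy deletion."""

    def __init__(self) -> None:
        self._heap: List[Tuple[int, str]] = []
        self._latest: Dict[str, int] = {}  # key -> most recent timestamp

    def touch(self, key: str, timestamp_ms: int) -> None:
        """Record an access; older heap entries for the key become stale."""
        self._latest[key] = timestamp_ms
        heapq.heappush(self._heap, (timestamp_ms, key))

    def pop_oldest(self) -> str:
        """Pop the least recently used key, skipping stale entries."""
        while self._heap:
            timestamp_ms, key = heapq.heappop(self._heap)
            if self._latest.get(key) == timestamp_ms:
                del self._latest[key]
                return key
        raise KeyError("heap is empty")


lru = LruHeap()
lru.touch("hello", 1)
lru.touch("world", 2)
lru.touch("hello", 3)  # "hello" is now newer than "world"
evicted = [lru.pop_oldest(), lru.pop_oldest()]
print(evicted)
```

The trade-off is that every access pays a heap push and stale entries linger until popped, in exchange for not sorting all of a tenant's nodes on each eviction.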


    def to_string(self) -> str:
Contributor

This method doesn't seem to be used. Is it necessary?

        return f"Node(text='{self.text}', parent={self.parent}, children={self.children}, tenant_last_access_time={self.tenant_last_access_time})"


@serve.deployment(name="TreeDeployment")
class PrefixTree:
Contributor

Thinking about this again, I don't think this PrefixTree needs to be a serve deployment since we are not planning on scaling it. It is basically a single actor shared between many processes. We can make it a remote actor that is accessible by name via the .get_actor method. I think this will also reduce the remote call overhead, because there would not be the extra layers of replica scheduler, etc. for this deployment.

    """
    Thread-safe multi-tenant prefix tree (approximate radix tree).
    Features:
    1. Stores data for multiple tenants in the same tree structure
    2. Node-level locking for concurrent access
    3. Leaf LRU eviction based on tenant access time
    """

    def __init__(self) -> None:
        self.lock: RLock = RLock()
        self.root: Node = Node()
        self.tenants: Set[str] = set()
Contributor

Add a comment on what each of .tenants, .tenant_char_count, and .tenant_nodes means?

        self.tenant_char_count: Dict[str, int] = {}
        self.tenant_nodes: Dict[str, Set[Node]] = {}

    def reset(self) -> None:
Contributor

On what occasion would this reset be called? Only in tests? If so, let's rename it to _reset and add a comment that it is only used for tests.

        """Reset the tree to an empty state."""
        with self.lock:
            self.root = Node()
            self.tenants = set()
Contributor

You need to clarify what a tenant is. What does it mean to have a prefix tree for multiple tenants?

My assumption is that a tenant is basically the replica_id.

            self.tenant_char_count = {}
            self.tenant_nodes = {}

    def to_dict(self) -> Dict:
Contributor

Same for this: if it's only used for tests, let's add a comment and apply the following.

Suggested change
-    def to_dict(self) -> Dict:
+    def _to_dict(self) -> Dict[str, Any]:

        return {
            "root": self.root,
            "tenants": self.tenants,
            "tenant_char_count": self.tenant_char_count,
            "tenant_nodes": self.tenant_nodes,
        }

    def to_string(self) -> str:
Contributor

If not used, then let's remove it.

        return f"PrefixTree(root={self.root.__str__()}, tenants={self.tenants}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes})"

    @staticmethod
    def shared_prefix_count(a: str, b: str) -> int:
        """Count the number of shared characters at the beginning of two strings."""
        i: int = 0
        for char_a, char_b in zip(a, b):
            if char_a == char_b:
                i += 1
            else:
                break
        return i
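
For reference, a standalone copy of this helper with a few sample inputs, written as a sketch for illustration. The sample strings are the ones used in the split example further down; the comparison against `os.path.commonprefix`, which works character by character on a list of strings, is an addition and not something the PR uses.

```python
import os


def shared_prefix_count(a: str, b: str) -> int:
    """Standalone copy of the PR's helper, for illustration only."""
    count = 0
    for char_a, char_b in zip(a, b):
        if char_a != char_b:
            break
        count += 1
    return count


assert shared_prefix_count("helloworld", "hellothere") == 5
assert shared_prefix_count("abc", "xyz") == 0
assert shared_prefix_count("abc", "abcdef") == 3  # zip stops at the shorter string

# The standard library computes the same prefix character by character.
assert len(os.path.commonprefix(["helloworld", "hellothere"])) == 5
```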

# def insert(self, text: str, tenant: str) -> Node:
# """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated)."""
# with self.lock:
# if tenant not in self.tenants:
# raise ValueError(f"Cannot insert text for tenant '{tenant}': tenant does not exist")

# curr_node: Node = self.root
# timestamp_ms: int = int(time.time() * 1000)
# i: int = 0
# while i < len(text):
# self.tenant_nodes[tenant].add(curr_node)

# first_char: str = text[i]
# curr_text: str = text[i:]
# if first_char not in curr_node.children:
# # No match, create new node
# # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")}
# new_node: Node = Node(text=curr_text, parent=curr_node)
# new_node.tenant_last_access_time[tenant] = timestamp_ms
# curr_node.children[first_char] = new_node

# # Match found, check if need to split
# matched_node: Node = curr_node.children[first_char]
# shared_count: int = self.shared_prefix_count(
# matched_node.text, curr_text
# )
# if shared_count == len(matched_node.text):
# # Full match, move down the tree
# if tenant not in matched_node.tenant_last_access_time:
# self.tenant_char_count[tenant] += shared_count
# matched_node.tenant_last_access_time[tenant] = timestamp_ms
# curr_node = matched_node
# else:
# # Partial match, split at matched point
# matched_text: str = matched_node.text[:shared_count]
# remaining_text: str = matched_node.text[shared_count:]
# new_parent: Node = Node(text=matched_text, parent=curr_node)
# matched_node.text = remaining_text
# matched_node.parent = new_parent
# new_parent.children[remaining_text[0]] = matched_node
# if tenant not in new_parent.tenant_last_access_time:
# self.tenant_char_count[tenant] += shared_count
# new_parent.tenant_last_access_time[tenant] = timestamp_ms
# curr_node = new_parent

# i += shared_count

# self.tenant_nodes[tenant].add(curr_node)
# return curr_node

    def insert(self, text: str, tenant: str) -> Node:
        """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated)."""
        with self.lock:
            if tenant not in self.tenants:
Contributor

Does it make sense to make add_tenant a private function and then, during insertion, simply add the tenant if it does not exist?

                raise ValueError(
                    f"Cannot insert text for tenant '{tenant}': tenant does not exist"
                )

            curr_node: Node = self.root
            timestamp_ms: int = int(time.time() * 1000)
            i: int = 0
            while i < len(text):
                curr_node.tenant_last_access_time[tenant] = timestamp_ms
Contributor

Was there a reason we can't just use float for this?

                self.tenant_nodes[tenant].add(curr_node)

                first_char: str = text[i]
                curr_text: str = text[i:]
                if first_char not in curr_node.children:
                    # No match, create new node
                    # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")}
                    new_node: Node = Node(text=curr_text, parent=curr_node)
                    new_node.tenant_last_access_time[tenant] = timestamp_ms

                    # Increment char count for tenant and add node to tenant_nodes
                    self.tenant_char_count[tenant] += len(curr_text)
                    self.tenant_nodes[tenant].add(new_node)

                    curr_node.children[first_char] = new_node
Contributor

When this no-match condition happens and you end up adding a new node to the current node, isn't it better to break out of the loop? Why do you need to continue?

Contributor

Yeah, I think I just missed a return new_node here.

                else:
                    # Match found, check if need to split
                    matched_node: Node = curr_node.children[first_char]
                    shared_count: int = self.shared_prefix_count(
                        matched_node.text, curr_text
                    )

                    if shared_count < len(matched_node.text):
                        # Partial match, split at matched point
                        # Example:
                        ## Before update:
                        ### curr_node.children = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5
Contributor

great explanation

                        ### matched_node = Node("helloworld")

                        ## During update:
                        ### Increment tenant_char_count[tenant] by shared_count if matched_node has not seen this tenant before

                        ## After update:
                        ### curr_node.children = {"h": Node("hello", children = {"w": Node("world")})}
                        ### parent_node = Node("hello"), matched_node = Node("world")
                        ### Update tenant_last_access_time for parent_node, NOT matched_node
                        ### (new) curr_text = "there", (new) curr_node = parent_node
                        ### Continue adding "there" to tree in next iteration

                        matched_text: str = matched_node.text[:shared_count]
                        remaining_text: str = matched_node.text[shared_count:]
Contributor

can this ever be an empty string?


                        # Update tenant char count for the new split node
                        if tenant not in matched_node.tenant_last_access_time:
                            self.tenant_char_count[tenant] += shared_count

                        # Create new parent node
                        new_parent: Node = Node(text=matched_text, parent=curr_node)
                        new_parent.tenant_last_access_time = (
                            matched_node.tenant_last_access_time.copy()
                        )
                        # Update matched_node
                        matched_node.text = remaining_text
                        matched_node.parent = new_parent

                        # Connect new parent node to matched_node
                        new_parent.children[remaining_text[0]] = matched_node

                        # Connect current node to new parent
                        curr_node.children[first_char] = new_parent

                        # Move down the tree
                        curr_node = new_parent
                        i += shared_count
                    else:
                        # Full match

                        # Update tenant char count if this is a new tenant for this node
                        if tenant not in matched_node.tenant_last_access_time:
                            self.tenant_char_count[tenant] += shared_count

                        # # Update tenant last access time
                        # matched_node.tenant_last_access_time[tenant] = timestamp_ms
Contributor

What's up with this commented-out code?


                        # Move down the tree
                        curr_node = matched_node
                        i += shared_count
            curr_node.tenant_last_access_time[tenant] = timestamp_ms
            self.tenant_nodes[tenant].add(curr_node)
            return curr_node

    def prefix_match(
        self, text: str, available_tenants: Optional[List[str]] = None
    ) -> Tuple[str, Optional[List[str]]]:
        """
        Match text against tree and return (matched_text, matched_tenants).
        Does not update access time for the matched tenants (only updates when insert() is called).
        If available_tenants is not provided, all tenants are considered.
        """
        if available_tenants:
            # Filter available_tenants to only include those that exist in the tree
            available_tenants = [
                tenant for tenant in available_tenants if tenant in self.tenants
            ]
            if not available_tenants:
                return "", None
        else:
            available_tenants = list(self.tenants)

        with self.lock:
Contributor

hmm wouldn't this prevent multiple prefix_match() calls from happening at the same time? As well as it will wait for insert to complete before allowing this call? I think we should probably go with Cody's idea of having two trees, one for query and another for update and switch the two at a fixed period?

But speaking of which was the performance not impacted with all those locks blocking each other when you run the benchmark??
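
A rough sketch of the query/update split, under loud assumptions: this is a snapshot-publish variant (the pending structure is copied into the active slot) rather than a literal swap of two trees, and a plain set of strings stands in for the prefix tree. `DoubleBuffered`, `contains`, and `swap` are hypothetical names, and nothing here schedules `swap`; a periodic timer or background task would have to call it.

```python
import copy
from threading import Lock
from typing import Set


class DoubleBuffered:
    """Hypothetical wrapper: readers see a published snapshot, writers edit a pending copy."""

    def __init__(self) -> None:
        self._write_lock = Lock()
        self._pending: Set[str] = set()  # stand-in for the tree being updated
        self._active: Set[str] = set()   # stand-in for the tree being queried

    def insert(self, text: str) -> None:
        with self._write_lock:  # writers only contend with each other
            self._pending.add(text)

    def contains(self, text: str) -> bool:
        # No lock: the published snapshot is never mutated, only replaced.
        return text in self._active

    def swap(self) -> None:
        """Publish pending writes; intended to run at a fixed period."""
        with self._write_lock:
            self._active = copy.deepcopy(self._pending)


buf = DoubleBuffered()
buf.insert("hello")
before = buf.contains("hello")  # not yet published
buf.swap()
after = buf.contains("hello")
print(before, after)
```

The cost of this variant is that queries can be stale by up to one period and each publish copies the whole structure, so it only pays off if lock contention on the query path is actually measurable.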

            curr_node: Node = self.root
            i: int = 0
            text_len: int = len(text)

            while i < text_len:
                first_char: str = text[i]
                curr_text: str = text[i:]

                if first_char in curr_node.children:
                    matched_node: Node = curr_node.children[first_char]

                    # Check if any of the available tenants match this node
                    if not any(
                        tenant in matched_node.tenant_last_access_time
                        for tenant in available_tenants
                    ):
                        break

                    shared_count: int = self.shared_prefix_count(
                        matched_node.text, curr_text
                    )
                    i += shared_count
                    curr_node = matched_node

                    if shared_count < len(matched_node.text):
                        # Partial match, stop here
                        break
                else:
                    # No match found, stop here
                    break

            # Select the tenants in available_tenants that are in the current node
            selected_tenants: Optional[List[str]] = None
            matching_tenants = [
                tenant
                for tenant in available_tenants
                if tenant in curr_node.tenant_last_access_time
            ]
            if matching_tenants:
Contributor

This if statement is unnecessary, we can just compile selected_tenants directly from the list comprehension above right?

                selected_tenants = matching_tenants

            ret_text: str = text[:i]
            return ret_text, selected_tenants

    def add_tenant(self, tenant: str) -> None:
        """Add a tenant to the tree."""
        with self.lock:
            if tenant in self.tenants:
                raise ValueError(f"Cannot add tenant '{tenant}': tenant already exists")
Contributor

Do you want to really raise an error here? or is it better to make this idempotent? meaning that if you add the same tenant it basically becomes a no-op (maybe with a warning log)
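
A minimal sketch of what the idempotent behaviour could look like, assuming a warning log plus a return value is acceptable in place of the exception. `TenantRegistry` is a hypothetical reduction of the class to its tenant bookkeeping; locking and node cleanup are deliberately left out.

```python
import logging
from typing import Dict, Set

logger = logging.getLogger(__name__)


class TenantRegistry:
    """Hypothetical reduction of PrefixTree to its tenant bookkeeping."""

    def __init__(self) -> None:
        self.tenants: Set[str] = set()
        self.tenant_char_count: Dict[str, int] = {}

    def add_tenant(self, tenant: str) -> bool:
        """Add a tenant; return False (and warn) if it already exists."""
        if tenant in self.tenants:
            logger.warning("Tenant '%s' already exists; skipping.", tenant)
            return False
        self.tenants.add(tenant)
        self.tenant_char_count[tenant] = 0
        return True

    def remove_tenant(self, tenant: str) -> int:
        """Remove a tenant; return the characters freed, 0 if it was unknown."""
        if tenant not in self.tenants:
            logger.warning("Tenant '%s' does not exist; skipping.", tenant)
            return 0
        self.tenants.remove(tenant)
        return self.tenant_char_count.pop(tenant)


registry = TenantRegistry()
registry.add_tenant("replica-1")
registry.add_tenant("replica-1")  # second call is a no-op with a warning
```

Repeated calls then leave the state unchanged, which matters if several callers may register or remove the same replica concurrently.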


            self.tenants.add(tenant)
            self.tenant_char_count[tenant] = 0
            self.tenant_nodes[tenant] = set()

    def remove_tenant(self, tenant: str) -> int:
        """Remove a tenant's nodes from the tree, returns the number of characters removed. Also removes the tenant from tenants, tenant_char_count, and tenant_nodes."""
        with self.lock:
            if tenant not in self.tenants:
                raise ValueError(
Contributor

Similar to add_tenant, we probably don't want to raise an error here and just want to make sure the call is idempotent.

                    f"Cannot remove tenant '{tenant}': tenant does not exist"
                )

            total_chars_removed: int = 0
            for node in self.tenant_nodes[tenant].copy():
                total_chars_removed += self.remove_tenant_single_node(tenant, node)

            self.tenants.remove(tenant)
            self.tenant_nodes.pop(tenant, None)
            self.tenant_char_count.pop(tenant, None)
            return total_chars_removed

    def remove_tenant_single_node(self, tenant: str, node: Node) -> int:
Contributor

This should be a private function. Rename it to _remove_tenant_single_node.

        """Remove a single node belonging to a tenant, returns the number of characters removed."""
        with self.lock:
            if tenant not in self.tenants:
                raise ValueError(
Contributor

Same for this: we probably don't want to raise an error here.

                    f"Cannot remove tenant '{tenant}': tenant does not exist"
                )
            if (
                node not in self.tenant_nodes[tenant]
                or tenant not in node.tenant_last_access_time
            ):
                raise ValueError(
Contributor

And here

                    f"Cannot remove node '{node.text}' from tenant '{tenant}': tenant does not have this node"
                )

            removed_chars_len: int = len(node.text)
            self.tenant_char_count[tenant] -= removed_chars_len
Contributor

Is this for checking for divergence, if ever?

            self.tenant_nodes[tenant].remove(node)
            node.tenant_last_access_time.pop(tenant, None)
            # If this node has no more tenants, remove it from the parent
            if not node.tenant_last_access_time and node.parent:
                node.parent.children.pop(node.text[0], None)

            return removed_chars_len

    def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int:
Contributor

Let's keep the method name all lowercase.

Suggested change
-    def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int:
+    def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int:

Contributor

Also, who is the caller of this method? Should there be a background task running this?

        """Evict nodes from a tenant until the removed character count is at least min_remove_size.
        Args:
            tenant: The tenant to evict nodes from
            min_remove_size: Minimum number of characters to remove
        Returns:
            int: The actual number of characters removed
        """
        with self.lock:
            if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]:
                raise ValueError(
                    f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes"
                )

            if self.tenant_char_count[tenant] < min_remove_size:
                raise ValueError(
Contributor

Should we just make this method idempotent and not raise those errors?

                    f"Cannot evict tenant '{tenant}': total character count is less than min_remove_size"
                )

            # Sort nodes by last access time (oldest first)
            nodes_to_evict = sorted(
Contributor

Use a heap so that you don't have to do this sorting?

                self.tenant_nodes[tenant],
                key=lambda node: node.tenant_last_access_time.get(tenant, 0),
            )

            total_chars_removed: int = 0

            # Remove nodes until we've reached the minimum removal size
            for node in nodes_to_evict.copy():
Contributor

Suggested change
-            for node in nodes_to_evict.copy():
+            while total_chars_removed < min_remove_size:

                # Use existing function to remove tenant from node
                total_chars_removed += self.remove_tenant_single_node(tenant, node)

                # Check if we've removed enough characters
                if total_chars_removed >= min_remove_size:
                    break

            return total_chars_removed

    def get_smallest_tenant(self) -> Optional[str]:
Contributor

Again, I asked AI to rewrite this in a better way; it looks better and more readable (rule of thumb: you should be allergic to seeing hardcoded indices, e.g. foo[0], in Python).

def get_smallest_tenant(self) -> Optional[str]:
    """Get the tenant with the smallest total character count."""
    with self.lock:
        return min(self.tenant_char_count, key=self.tenant_char_count.get, default=None)

        """Get the tenant with the smallest total character count."""
        with self.lock:
            if not self.tenant_char_count:
                return None

            return min(self.tenant_char_count.items(), key=lambda x: x[1])[0]