From c691e8b39618651bd9da32851843dd026a93ab68 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Fri, 2 May 2025 10:35:17 -0700 Subject: [PATCH 01/15] Add prefix tree class Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 232 ++++++++++++++++++ 1 file changed, 232 insertions(+) create mode 100644 python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py new file mode 100644 index 0000000000000..a3dc8cce0d0a5 --- /dev/null +++ b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py @@ -0,0 +1,232 @@ +from ray import serve +import time +from collections import defaultdict +from threading import Lock, RLock +import json +from typing import Optional, List, Tuple, Dict, Any, TypeVar, Union, cast + +class Node: + """ + Node in a prefix tree that tracks tenant access time. + + Each node represents a segment of text and can belong to multiple tenants. + """ + def __init__(self, text: str = "", parent: Optional["Node"] = None): + self.text: str = text + self.parent: Optional["Node"] = parent + self.children: Dict[str, "Node"] = {} # Maps char -> Node + self.tenant_last_access_time: Dict[str, int] = {} # Maps tenant -> timestamp in ms (int) + + def __str__(self) -> str: + return f"Node(text='{self.text}', tenants={list(self.tenant_last_access_time.keys())})" + +@serve.deployment(name="TreeDeployment") +class PrefixTree: + """ + Thread-safe multi-tenant prefix tree (approximate radix tree). + + Features: + 1. Stores data for multiple tenants in the same tree structure + 2. Node-level locking for concurrent access + 3. Leaf LRU eviction based on tenant access time + """ + def __init__(self) -> None: + self.root: Node = Node() + self.tenant_char_count: Dict[str, int] = defaultdict(int) # Maps tenant -> character count + self.lock: Lock = Lock() # For operations that need to lock the entire tree + + @staticmethod + def shared_prefix_count(a: str, b: str) -> int: + """Count the number of shared characters at the beginning of two strings.""" + i: int = 0 + for char_a, char_b in zip(a, b): + if char_a == char_b: + i += 1 + else: + break + return i + + def insert(self, text: str, tenant: str) -> None: + """Insert text into tree with given tenant.""" + with self.lock: + curr_node: Node = self.root + timestamp_ms: int = int(time.time() * 1000) + i: int = 0 + while i < len(text): + curr_node.tenant_last_access_time[tenant] = timestamp_ms + first_char: str = text[i] + curr_text: str = text[i:] + if first_char not in curr_node.children: + # No match, create new node + # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} + new_node: Node = Node(text=curr_text, parent=curr_node) + new_node.tenant_last_access_time[tenant] = timestamp_ms + + # Increment char count for tenant + self.tenant_char_count[tenant] += len(curr_text) + + curr_node.children[first_char] = new_node + else: + # Match found, check if need to split + matched_node: Node = curr_node.children[first_char] + shared_count: int = self.shared_prefix_count(matched_node.text, curr_text) + + if shared_count < len(matched_node.text): + # Partial match, split at matched point + # Example: + ## Before update: + ### curr_node.children = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 + ### matched_node = Node("helloworld") + + ## During update: + ### Increment tenant_char_count[tenant] by shared_count if matched_node has not seen this tenant before + + ## After update: + ### curr_node.children = {"h": Node("hello", children = {"w": Node("world")})} + ### parent_node = Node("hello"), matched_node = Node("world") + ### Update tenant_last_access_time for parent_node, NOT matched_node + ### (new) curr_text = "there", (new) curr_node = parent_node + ### Continue adding "there" to tree in next iteration + + matched_text: str = matched_node.text[:shared_count] + remaining_text: str = matched_node.text[shared_count:] + + # Update tenant char count for the new split node + if tenant not in matched_node.tenant_last_access_time: + self.tenant_char_count[tenant] += shared_count + + # Create new parent node + new_parent: Node = Node(text=matched_text, parent=curr_node) + new_parent.tenant_last_access_time = matched_node.tenant_last_access_time.copy() + + # Update matched_node + matched_node.text = remaining_text + matched_node.parent = new_parent + + # Connect new parent node to matched_node + new_parent.children[remaining_text[0]] = matched_node + + # Connect current node to new parent + curr_node.children[first_char] = new_parent + + # Move down the tree + curr_node = new_parent + i += shared_count + else: + # Full match + + # Update tenant char count if this is a new tenant for this node + if tenant not in matched_node.tenant_last_access_time: + self.tenant_char_count[tenant] += shared_count + + # # Update tenant last access time + # matched_node.tenant_last_access_time[tenant] = timestamp_ms + + # Move down the tree + curr_node = matched_node + i += shared_count + + def prefix_match(self, text: str, available_tenants: Optional[List[str]] = None) -> Tuple[str, Optional[List[str]]]: + """ + Match text against tree and return (matched_text, matched_tenants). + Does not update access time for the matched tenants (only updates when insert() is called). + """ + with self.lock: + curr_node: Node = self.root + i: int = 0 + text_len: int = len(text) + + while i < text_len: + first_char: str = text[i] + curr_text: str = text[i:] + + if first_char in curr_node.children: + matched_node: Node = curr_node.children[first_char] + + # Check if any of the available tenants match this node + if available_tenants: + if not any(tenant in matched_node.tenant_last_access_time for tenant in available_tenants): + break + + shared_count: int = self.shared_prefix_count(matched_node.text, curr_text) + i += shared_count + curr_node = matched_node + + if shared_count < len(matched_node.text): + # Partial match, stop here + break + else: + # No match found, stop here + break + + # Select the tenants in available_tenants that are in the current node + selected_tenants: Optional[List[str]] = None + if available_tenants: + matching_tenants = [tenant for tenant in available_tenants if tenant in curr_node.tenant_last_access_time] + if matching_tenants: + selected_tenants = matching_tenants + else: + if curr_node.tenant_last_access_time: + selected_tenants = list(curr_node.tenant_last_access_time) + + ret_text: str = text[:i] + return ret_text, selected_tenants + + def get_smallest_tenant(self) -> Optional[str]: + """Get the tenant with the smallest total character count.""" + with self.lock: + if not self.tenant_char_count: + # Return first worker if no data yet + return None + + return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] + + def evict_tenant_by_size(self, max_size: int) -> None: + """Evict nodes for tenants that exceed the maximum tree size.""" + with self.lock: + # Get total tree size + total_size: int = sum(self.tenant_char_count.values()) + + # If tree is smaller than max size, no need to evict + if total_size <= max_size: + return + + # Calculate how much we need to evict + excess: int = total_size - max_size + + # Sort tenants by size (largest first) + sorted_tenants: List[Tuple[str, int]] = sorted( + self.tenant_char_count.items(), + key=lambda x: x[1], + reverse=True + ) + + # Evict from largest tenants first + for tenant, size in sorted_tenants: + # If we've evicted enough, stop + if excess <= 0: + break + + # Calculate how much to evict from this tenant + # Evict at most half of the tenant's size + evict_amount: int = min(excess, size // 2) + + if evict_amount > 0: + # print(f"Evicting {evict_amount} chars from tenant {tenant}") + self.tenant_char_count[tenant] -= evict_amount + excess -= evict_amount + + # print(f"Tree eviction complete. New size: {sum(self.tenant_char_count.values())}") + + def get_tenant_char_count(self) -> Dict[str, int]: + """Get character count for each tenant.""" + with self.lock: + return dict(self.tenant_char_count) + + def remove_tenant(self, tenant: str) -> None: + """Remove all nodes belonging to a tenant.""" + # Would require a traversal of the tree and removing the tenant + # from tenant_last_access_time. Simplifying for now. + with self.lock: + if tenant in self.tenant_char_count: + del self.tenant_char_count[tenant] \ No newline at end of file From 939aa7e84d736b2749d167fd831c20f23f1ccd18 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Fri, 2 May 2025 10:51:09 -0700 Subject: [PATCH 02/15] Linting Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 170 ++++++++++-------- 1 file changed, 100 insertions(+), 70 deletions(-) diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py index a3dc8cce0d0a5..fd79d1be41d71 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py +++ b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py @@ -1,40 +1,50 @@ from ray import serve import time from collections import defaultdict -from threading import Lock, RLock -import json -from typing import Optional, List, Tuple, Dict, Any, TypeVar, Union, cast +from threading import Lock +from typing import Optional, List, Tuple, Dict + class Node: """ Node in a prefix tree that tracks tenant access time. - + Each node represents a segment of text and can belong to multiple tenants. """ + def __init__(self, text: str = "", parent: Optional["Node"] = None): self.text: str = text self.parent: Optional["Node"] = parent self.children: Dict[str, "Node"] = {} # Maps char -> Node - self.tenant_last_access_time: Dict[str, int] = {} # Maps tenant -> timestamp in ms (int) - + self.tenant_last_access_time: Dict[ + str, int + ] = {} # Maps tenant -> timestamp in ms (int) + def __str__(self) -> str: return f"Node(text='{self.text}', tenants={list(self.tenant_last_access_time.keys())})" + @serve.deployment(name="TreeDeployment") class PrefixTree: """ Thread-safe multi-tenant prefix tree (approximate radix tree). - + Features: 1. Stores data for multiple tenants in the same tree structure 2. Node-level locking for concurrent access 3. Leaf LRU eviction based on tenant access time """ + def __init__(self) -> None: self.root: Node = Node() - self.tenant_char_count: Dict[str, int] = defaultdict(int) # Maps tenant -> character count + self.tenant_char_count: Dict[str, int] = defaultdict( + int + ) # Maps tenant -> character count self.lock: Lock = Lock() # For operations that need to lock the entire tree - + self.tenant_nodes: Dict[str, List[Node]] = defaultdict( + list + ) # Maps tenant -> list of nodes it belongs to + @staticmethod def shared_prefix_count(a: str, b: str) -> int: """Count the number of shared characters at the beginning of two strings.""" @@ -61,16 +71,19 @@ def insert(self, text: str, tenant: str) -> None: # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} new_node: Node = Node(text=curr_text, parent=curr_node) new_node.tenant_last_access_time[tenant] = timestamp_ms - - # Increment char count for tenant + + # Increment char count for tenant and add node to tenant_nodes self.tenant_char_count[tenant] += len(curr_text) - + self.tenant_nodes[tenant].append(new_node) + curr_node.children[first_char] = new_node else: # Match found, check if need to split matched_node: Node = curr_node.children[first_char] - shared_count: int = self.shared_prefix_count(matched_node.text, curr_text) - + shared_count: int = self.shared_prefix_count( + matched_node.text, curr_text + ) + if shared_count < len(matched_node.text): # Partial match, split at matched point # Example: @@ -94,11 +107,13 @@ def insert(self, text: str, tenant: str) -> None: # Update tenant char count for the new split node if tenant not in matched_node.tenant_last_access_time: self.tenant_char_count[tenant] += shared_count - + # Create new parent node new_parent: Node = Node(text=matched_text, parent=curr_node) - new_parent.tenant_last_access_time = matched_node.tenant_last_access_time.copy() - + new_parent.tenant_last_access_time = ( + matched_node.tenant_last_access_time.copy() + ) + self.tenant_nodes[tenant].append(new_parent) # Update matched_node matched_node.text = remaining_text matched_node.parent = new_parent @@ -108,7 +123,7 @@ def insert(self, text: str, tenant: str) -> None: # Connect current node to new parent curr_node.children[first_char] = new_parent - + # Move down the tree curr_node = new_parent i += shared_count @@ -118,7 +133,7 @@ def insert(self, text: str, tenant: str) -> None: # Update tenant char count if this is a new tenant for this node if tenant not in matched_node.tenant_last_access_time: self.tenant_char_count[tenant] += shared_count - + # # Update tenant last access time # matched_node.tenant_last_access_time[tenant] = timestamp_ms @@ -126,7 +141,9 @@ def insert(self, text: str, tenant: str) -> None: curr_node = matched_node i += shared_count - def prefix_match(self, text: str, available_tenants: Optional[List[str]] = None) -> Tuple[str, Optional[List[str]]]: + def prefix_match( + self, text: str, available_tenants: Optional[List[str]] = None + ) -> Tuple[str, Optional[List[str]]]: """ Match text against tree and return (matched_text, matched_tenants). Does not update access time for the matched tenants (only updates when insert() is called). @@ -135,20 +152,25 @@ def prefix_match(self, text: str, available_tenants: Optional[List[str]] = None) curr_node: Node = self.root i: int = 0 text_len: int = len(text) - + while i < text_len: first_char: str = text[i] curr_text: str = text[i:] - + if first_char in curr_node.children: matched_node: Node = curr_node.children[first_char] # Check if any of the available tenants match this node if available_tenants: - if not any(tenant in matched_node.tenant_last_access_time for tenant in available_tenants): + if not any( + tenant in matched_node.tenant_last_access_time + for tenant in available_tenants + ): break - shared_count: int = self.shared_prefix_count(matched_node.text, curr_text) + shared_count: int = self.shared_prefix_count( + matched_node.text, curr_text + ) i += shared_count curr_node = matched_node @@ -158,75 +180,83 @@ def prefix_match(self, text: str, available_tenants: Optional[List[str]] = None) else: # No match found, stop here break - + # Select the tenants in available_tenants that are in the current node selected_tenants: Optional[List[str]] = None if available_tenants: - matching_tenants = [tenant for tenant in available_tenants if tenant in curr_node.tenant_last_access_time] + matching_tenants = [ + tenant + for tenant in available_tenants + if tenant in curr_node.tenant_last_access_time + ] if matching_tenants: selected_tenants = matching_tenants else: if curr_node.tenant_last_access_time: selected_tenants = list(curr_node.tenant_last_access_time) - + ret_text: str = text[:i] return ret_text, selected_tenants - + def get_smallest_tenant(self) -> Optional[str]: """Get the tenant with the smallest total character count.""" with self.lock: if not self.tenant_char_count: # Return first worker if no data yet return None - + return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] - - def evict_tenant_by_size(self, max_size: int) -> None: - """Evict nodes for tenants that exceed the maximum tree size.""" - with self.lock: - # Get total tree size - total_size: int = sum(self.tenant_char_count.values()) - - # If tree is smaller than max size, no need to evict - if total_size <= max_size: - return - - # Calculate how much we need to evict - excess: int = total_size - max_size - - # Sort tenants by size (largest first) - sorted_tenants: List[Tuple[str, int]] = sorted( - self.tenant_char_count.items(), - key=lambda x: x[1], - reverse=True - ) - - # Evict from largest tenants first - for tenant, size in sorted_tenants: - # If we've evicted enough, stop - if excess <= 0: - break - - # Calculate how much to evict from this tenant - # Evict at most half of the tenant's size - evict_amount: int = min(excess, size // 2) - - if evict_amount > 0: - # print(f"Evicting {evict_amount} chars from tenant {tenant}") - self.tenant_char_count[tenant] -= evict_amount - excess -= evict_amount - - # print(f"Tree eviction complete. New size: {sum(self.tenant_char_count.values())}") - + + # def evict_tenant_by_size(self, max_size: int) -> None: + # """Evict nodes for tenants that exceed the maximum tree size.""" + # with self.lock: + # # Get total tree size + # total_size: int = sum(self.tenant_char_count.values()) + + # # If tree is smaller than max size, no need to evict + # if total_size <= max_size: + # return + + # # Calculate how much we need to evict + # excess: int = total_size - max_size + + # # Sort tenants by size (largest first) + # sorted_tenants: List[Tuple[str, int]] = sorted( + # self.tenant_char_count.items(), + # key=lambda x: x[1], + # reverse=True + # ) + + # # Evict from largest tenants first + # for tenant, size in sorted_tenants: + # # If we've evicted enough, stop + # if excess <= 0: + # break + + # # Calculate how much to evict from this tenant + # # Evict at most half of the tenant's size + # evict_amount: int = min(excess, size // 2) + + # if evict_amount > 0: + # # print(f"Evicting {evict_amount} chars from tenant {tenant}") + # self.tenant_char_count[tenant] -= evict_amount + # excess -= evict_amount + + # # print(f"Tree eviction complete. New size: {sum(self.tenant_char_count.values())}") + def get_tenant_char_count(self) -> Dict[str, int]: """Get character count for each tenant.""" with self.lock: return dict(self.tenant_char_count) - + def remove_tenant(self, tenant: str) -> None: """Remove all nodes belonging to a tenant.""" # Would require a traversal of the tree and removing the tenant # from tenant_last_access_time. Simplifying for now. with self.lock: - if tenant in self.tenant_char_count: - del self.tenant_char_count[tenant] \ No newline at end of file + for node in self.tenant_nodes[tenant]: + node.tenant_last_access_time.pop(tenant, None) + if not node.tenant_last_access_time: + node.parent.children.pop(node.text[0], None) + self.tenant_char_count.pop(tenant, None) + self.tenant_nodes.pop(tenant, None) From 8eb8a9c711abef1bf8616587cc6f9bf76d51513c Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Fri, 2 May 2025 12:54:09 -0700 Subject: [PATCH 03/15] Add test cases Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 113 ++++++++------ .../serve/cpu/deployments/test_prefix_tree.py | 147 ++++++++++++++++++ 2 files changed, 213 insertions(+), 47 deletions(-) create mode 100644 python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py index fd79d1be41d71..9d208edb14fda 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py +++ b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py @@ -1,7 +1,7 @@ from ray import serve import time from collections import defaultdict -from threading import Lock +from threading import RLock from typing import Optional, List, Tuple, Dict @@ -40,11 +40,23 @@ def __init__(self) -> None: self.tenant_char_count: Dict[str, int] = defaultdict( int ) # Maps tenant -> character count - self.lock: Lock = Lock() # For operations that need to lock the entire tree + self.lock: RLock = RLock() # For operations that need to lock the entire tree self.tenant_nodes: Dict[str, List[Node]] = defaultdict( list ) # Maps tenant -> list of nodes it belongs to + def get_root(self) -> Node: + return self.root + + def get_tenant_char_count(self) -> Dict[str, int]: + return self.tenant_char_count + + def get_tenant_nodes(self) -> Dict[str, List[Node]]: + return self.tenant_nodes + + def to_string(self) -> str: + return f"PrefixTree(root={self.root.__str__()}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes})" + @staticmethod def shared_prefix_count(a: str, b: str) -> int: """Count the number of shared characters at the beginning of two strings.""" @@ -202,61 +214,68 @@ def get_smallest_tenant(self) -> Optional[str]: """Get the tenant with the smallest total character count.""" with self.lock: if not self.tenant_char_count: - # Return first worker if no data yet return None return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] - # def evict_tenant_by_size(self, max_size: int) -> None: - # """Evict nodes for tenants that exceed the maximum tree size.""" - # with self.lock: - # # Get total tree size - # total_size: int = sum(self.tenant_char_count.values()) - - # # If tree is smaller than max size, no need to evict - # if total_size <= max_size: - # return + def get_tenant_char_count(self) -> Dict[str, int]: + """Get character count for each tenant.""" + with self.lock: + return dict(self.tenant_char_count) - # # Calculate how much we need to evict - # excess: int = total_size - max_size + def remove_tenant_entirely(self, tenant: str) -> int: + """Remove all nodes belonging to a tenant, returns the number of characters removed. Also removes the tenant from tenant_char_count and tenant_nodes.""" + with self.lock: + total_chars_removed: int = 0 + for node in self.tenant_nodes[tenant].copy(): + total_chars_removed += self.remove_tenant_single_node(tenant, node) + self.tenant_nodes.pop(tenant, None) + self.tenant_char_count.pop(tenant, None) + return total_chars_removed - # # Sort tenants by size (largest first) - # sorted_tenants: List[Tuple[str, int]] = sorted( - # self.tenant_char_count.items(), - # key=lambda x: x[1], - # reverse=True - # ) + def remove_tenant_single_node(self, tenant: str, node: Node) -> int: + """Remove a single node belonging to a tenant, returns the number of characters removed.""" + with self.lock: + removed_chars_len: int = len(node.text) + self.tenant_char_count[tenant] -= removed_chars_len + self.tenant_nodes[tenant].remove(node) + node.tenant_last_access_time.pop(tenant, None) - # # Evict from largest tenants first - # for tenant, size in sorted_tenants: - # # If we've evicted enough, stop - # if excess <= 0: - # break + # If this node has no more tenants, remove it from the parent + if not node.tenant_last_access_time: + node.parent.children.pop(node.text[0], None) - # # Calculate how much to evict from this tenant - # # Evict at most half of the tenant's size - # evict_amount: int = min(excess, size // 2) + return removed_chars_len - # if evict_amount > 0: - # # print(f"Evicting {evict_amount} chars from tenant {tenant}") - # self.tenant_char_count[tenant] -= evict_amount - # excess -= evict_amount + def evict_tenant(self, tenant: str, min_remove_size: int) -> int: + """Evict nodes from a tenant until the removed character count is at least min_remove_size. - # # print(f"Tree eviction complete. New size: {sum(self.tenant_char_count.values())}") + Args: + tenant: The tenant to evict nodes from + min_remove_size: Minimum number of characters to remove - def get_tenant_char_count(self) -> Dict[str, int]: - """Get character count for each tenant.""" + Returns: + int: The actual number of characters removed + """ with self.lock: - return dict(self.tenant_char_count) + if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: + return 0 - def remove_tenant(self, tenant: str) -> None: - """Remove all nodes belonging to a tenant.""" - # Would require a traversal of the tree and removing the tenant - # from tenant_last_access_time. Simplifying for now. - with self.lock: - for node in self.tenant_nodes[tenant]: - node.tenant_last_access_time.pop(tenant, None) - if not node.tenant_last_access_time: - node.parent.children.pop(node.text[0], None) - self.tenant_char_count.pop(tenant, None) - self.tenant_nodes.pop(tenant, None) + # Sort nodes by last access time (oldest first) + nodes_to_evict = sorted( + self.tenant_nodes[tenant], + key=lambda node: node.tenant_last_access_time.get(tenant, 0), + ) + + total_chars_removed: int = 0 + + # Remove nodes until we've reached the minimum removal size + for node in nodes_to_evict.copy(): + # Use existing function to remove tenant from node + total_chars_removed += self.remove_tenant_single_node(tenant, node) + + # Check if we've removed enough characters + if total_chars_removed >= min_remove_size: + break + + return total_chars_removed diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py new file mode 100644 index 0000000000000..b8d6eda04a8dc --- /dev/null +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -0,0 +1,147 @@ +import pytest +import time +import ray +from ray import serve + +from ray.llm._internal.serve.deployments.routers.prefix_tree import PrefixTree + + +@pytest.fixture(scope="module", autouse=True) +def serve_instance(): + # Start Ray and Serve once per test module + ray.init(ignore_reinit_error=True) + serve.start(detached=True) + yield + serve.shutdown() + ray.shutdown() + + +@pytest.mark.asyncio +async def test_insert_and_basic_match(): + # Deploy a clean PrefixTree + tree = serve.run(PrefixTree.bind()) + + # Insert and match exact string + await tree.insert.remote("hello", "tenant-A") + matched_text, tenants = await tree.prefix_match.remote("hello") + assert matched_text == "hello" + assert tenants == ["tenant-A"] + + +@pytest.mark.asyncio +async def test_tree_splits_nodes_on_partial_match(): + tree = serve.run(PrefixTree.bind()) + await tree.insert.remote("helloworld", "A") + await tree.insert.remote("hellothere", "B") + + # After inserting both, the root should have one child "h" + root = await tree.get_root.remote() + h_node = root.children.get("h") + assert h_node is not None + + +@pytest.mark.asyncio +async def test_no_match(): + tree = serve.run(PrefixTree.bind()) + matched_text, tenants = await tree.prefix_match.remote("hello") + assert matched_text == "" + assert tenants is None + + +@pytest.mark.asyncio +async def test_duplicate_insertion_no_double_count(): + tree = serve.run(PrefixTree.bind()) + + await tree.insert.remote("foo", "T1") + await tree.insert.remote("foo", "T1") # duplicate + + counts = await tree.get_tenant_char_count.remote() + # Should count 'foo' only once + assert counts.get("T1", 0) == 3 + + +@pytest.mark.asyncio +async def test_shared_prefix_splitting_and_branching(): + tree = serve.run(PrefixTree.bind()) + + await tree.insert.remote("helloworld", "A") + await tree.insert.remote("hellothere", "B") + + text_a, tenants_a = await tree.prefix_match.remote("helloworld") + text_b, tenants_b = await tree.prefix_match.remote("hellothere") + + assert text_a == "helloworld" + assert tenants_a == ["A"] + assert text_b == "hellothere" + assert tenants_b == ["B"] + + +@pytest.mark.asyncio +async def test_prefix_match_partial_and_filter(): + tree = serve.run(PrefixTree.bind()) + + await tree.insert.remote("apple", "X") + await tree.insert.remote("apricot", "Y") + + # Partial match for 'application' -> 'appl' + text, tenants = await tree.prefix_match.remote("application") + assert text == "appl" + assert tenants == ["X"] + + # Filter by available_tenants=['X'] on 'apricot' + text_fx, tenants_fx = await tree.prefix_match.remote("apricot", ["X"]) + assert text_fx == "ap" + assert tenants_fx == ["X"] + + # Filter by non-existent tenant yields no tenants + text_fz, tenants_fz = await tree.prefix_match.remote("apricot", ["Z"]) + assert text_fz == "" + assert tenants_fz is None + + +@pytest.mark.asyncio +async def test_remove_and_get_smallest_and_evict(): + tree = serve.run(PrefixTree.bind()) + + # Test removal and char count + await tree.insert.remote("cat", "T1") + await tree.insert.remote("dog", "T1") + + counts = await tree.get_tenant_char_count.remote() + assert counts.get("T1") == len("cat") + len("dog") + + # Remove entire tenant + removed = await tree.remove_tenant_entirely.remote("T1") + assert removed == len("cat") + len("dog") + + counts_after = await tree.get_tenant_char_count.remote() + assert "T1" not in counts_after + + # Test eviction LRU behavior + await tree.insert.remote("a", "T2") + time.sleep(0.001) + await tree.insert.remote("bb", "T2") + time.sleep(0.001) + await tree.insert.remote("ccc", "T2") + + before = (await tree.get_tenant_char_count.remote())["T2"] + evicted = await tree.evict_tenant.remote("T2", 2) + after = (await tree.get_tenant_char_count.remote())["T2"] + + assert evicted >= 2 + assert before - after == evicted + + +@pytest.mark.asyncio +async def test_get_smallest_tenant(): + tree = serve.run(PrefixTree.bind()) + await tree.insert.remote("aaaa", "A") + await tree.insert.remote("bb", "B") + smallest = await tree.get_smallest_tenant.remote() + assert smallest == "B" + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(["-v", __file__])) From 3e55d54d26b7343c1fa2660763b28dfb3df16b1a Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Fri, 2 May 2025 18:08:17 -0700 Subject: [PATCH 04/15] Implement eviction, tests passing Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 202 ++++++--- .../serve/cpu/deployments/test_prefix_tree.py | 397 ++++++++++++++---- 2 files changed, 455 insertions(+), 144 deletions(-) diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py index 9d208edb14fda..a5de6674aacf9 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py +++ b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py @@ -2,7 +2,7 @@ import time from collections import defaultdict from threading import RLock -from typing import Optional, List, Tuple, Dict +from typing import Optional, List, Tuple, Dict, Set class Node: @@ -20,8 +20,8 @@ def __init__(self, text: str = "", parent: Optional["Node"] = None): str, int ] = {} # Maps tenant -> timestamp in ms (int) - def __str__(self) -> str: - return f"Node(text='{self.text}', tenants={list(self.tenant_last_access_time.keys())})" + def to_string(self) -> str: + return f"Node(text='{self.text}', parent={self.parent}, children={self.children}, tenant_last_access_time={self.tenant_last_access_time})" @serve.deployment(name="TreeDeployment") @@ -36,26 +36,30 @@ class PrefixTree: """ def __init__(self) -> None: + self.lock: RLock = RLock() self.root: Node = Node() - self.tenant_char_count: Dict[str, int] = defaultdict( - int - ) # Maps tenant -> character count - self.lock: RLock = RLock() # For operations that need to lock the entire tree - self.tenant_nodes: Dict[str, List[Node]] = defaultdict( - list - ) # Maps tenant -> list of nodes it belongs to - - def get_root(self) -> Node: - return self.root + self.tenants: Set[str] = set() + self.tenant_char_count: Dict[str, int] = {} + self.tenant_nodes: Dict[str, Set[Node]] = {} - def get_tenant_char_count(self) -> Dict[str, int]: - return self.tenant_char_count - - def get_tenant_nodes(self) -> Dict[str, List[Node]]: - return self.tenant_nodes + def reset(self) -> None: + """Reset the tree to an empty state.""" + with self.lock: + self.root = Node() + self.tenants = set() + self.tenant_char_count = {} + self.tenant_nodes = {} + + def to_dict(self) -> Dict: + return { + "root": self.root, + "tenants": self.tenants, + "tenant_char_count": self.tenant_char_count, + "tenant_nodes": self.tenant_nodes + } def to_string(self) -> str: - return f"PrefixTree(root={self.root.__str__()}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes})" + return f"PrefixTree(root={self.root.__str__()}, tenants={self.tenants}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes})" @staticmethod def shared_prefix_count(a: str, b: str) -> int: @@ -68,14 +72,69 @@ def shared_prefix_count(a: str, b: str) -> int: break return i - def insert(self, text: str, tenant: str) -> None: - """Insert text into tree with given tenant.""" + # def insert(self, text: str, tenant: str) -> Node: + # """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated).""" + # with self.lock: + # if tenant not in self.tenants: + # raise ValueError(f"Cannot insert text for tenant '{tenant}': tenant does not exist") + + # curr_node: Node = self.root + # timestamp_ms: int = int(time.time() * 1000) + # i: int = 0 + # while i < len(text): + # self.tenant_nodes[tenant].add(curr_node) + + # first_char: str = text[i] + # curr_text: str = text[i:] + # if first_char not in curr_node.children: + # # No match, create new node + # # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} + # new_node: Node = Node(text=curr_text, parent=curr_node) + # new_node.tenant_last_access_time[tenant] = timestamp_ms + # curr_node.children[first_char] = new_node + + # # Match found, check if need to split + # matched_node: Node = curr_node.children[first_char] + # shared_count: int = self.shared_prefix_count( + # matched_node.text, curr_text + # ) + # if shared_count == len(matched_node.text): + # # Full match, move down the tree + # if tenant not in matched_node.tenant_last_access_time: + # self.tenant_char_count[tenant] += shared_count + # matched_node.tenant_last_access_time[tenant] = timestamp_ms + # curr_node = matched_node + # else: + # # Partial match, split at matched point + # matched_text: str = matched_node.text[:shared_count] + # remaining_text: str = matched_node.text[shared_count:] + # new_parent: Node = Node(text=matched_text, parent=curr_node) + # matched_node.text = remaining_text + # matched_node.parent = new_parent + # new_parent.children[remaining_text[0]] = matched_node + # if tenant not in new_parent.tenant_last_access_time: + # self.tenant_char_count[tenant] += shared_count + # new_parent.tenant_last_access_time[tenant] = timestamp_ms + # curr_node = new_parent + + # i += shared_count + + # self.tenant_nodes[tenant].add(curr_node) + # return curr_node + + def insert(self, text: str, tenant: str) -> Node: + """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated).""" with self.lock: + if tenant not in self.tenants: + raise ValueError(f"Cannot insert text for tenant '{tenant}': tenant does not exist") + curr_node: Node = self.root timestamp_ms: int = int(time.time() * 1000) i: int = 0 while i < len(text): curr_node.tenant_last_access_time[tenant] = timestamp_ms + self.tenant_nodes[tenant].add(curr_node) + first_char: str = text[i] curr_text: str = text[i:] if first_char not in curr_node.children: @@ -86,7 +145,7 @@ def insert(self, text: str, tenant: str) -> None: # Increment char count for tenant and add node to tenant_nodes self.tenant_char_count[tenant] += len(curr_text) - self.tenant_nodes[tenant].append(new_node) + self.tenant_nodes[tenant].add(new_node) curr_node.children[first_char] = new_node else: @@ -125,7 +184,6 @@ def insert(self, text: str, tenant: str) -> None: new_parent.tenant_last_access_time = ( matched_node.tenant_last_access_time.copy() ) - self.tenant_nodes[tenant].append(new_parent) # Update matched_node matched_node.text = remaining_text matched_node.parent = new_parent @@ -152,14 +210,28 @@ def insert(self, text: str, tenant: str) -> None: # Move down the tree curr_node = matched_node i += shared_count - + curr_node.tenant_last_access_time[tenant] = timestamp_ms + self.tenant_nodes[tenant].add(curr_node) + return curr_node def prefix_match( self, text: str, available_tenants: Optional[List[str]] = None ) -> Tuple[str, Optional[List[str]]]: """ Match text against tree and return (matched_text, matched_tenants). Does not update access time for the matched tenants (only updates when insert() is called). + If available_tenants is not provided, all tenants are considered. """ + if available_tenants: + # Filter available_tenants to only include those that exist in the tree + available_tenants = [ + tenant for tenant in available_tenants + if tenant in self.tenants + ] + if not available_tenants: + return "", None + else: + available_tenants = list(self.tenants) + with self.lock: curr_node: Node = self.root i: int = 0 @@ -173,12 +245,11 @@ def prefix_match( matched_node: Node = curr_node.children[first_char] # Check if any of the available tenants match this node - if available_tenants: - if not any( - tenant in matched_node.tenant_last_access_time - for tenant in available_tenants - ): - break + if not any( + tenant in matched_node.tenant_last_access_time + for tenant in available_tenants + ): + break shared_count: int = self.shared_prefix_count( matched_node.text, curr_text @@ -195,59 +266,65 @@ def prefix_match( # Select the tenants in available_tenants that are in the current node selected_tenants: Optional[List[str]] = None - if available_tenants: - matching_tenants = [ - tenant - for tenant in available_tenants - if tenant in curr_node.tenant_last_access_time - ] - if matching_tenants: - selected_tenants = matching_tenants - else: - if curr_node.tenant_last_access_time: - selected_tenants = list(curr_node.tenant_last_access_time) + matching_tenants = [ + tenant + for tenant in available_tenants + if tenant in curr_node.tenant_last_access_time + ] + if matching_tenants: + selected_tenants = matching_tenants ret_text: str = text[:i] return ret_text, selected_tenants - def get_smallest_tenant(self) -> Optional[str]: - """Get the tenant with the smallest total character count.""" + + def add_tenant(self, tenant: str) -> None: + """Add a tenant to the tree.""" with self.lock: - if not self.tenant_char_count: - return None + if tenant in self.tenants: + raise ValueError(f"Cannot add tenant '{tenant}': tenant already exists") - return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] + self.tenants.add(tenant) + self.tenant_char_count[tenant] = 0 + self.tenant_nodes[tenant] = set() - def get_tenant_char_count(self) -> Dict[str, int]: - """Get character count for each tenant.""" - with self.lock: - return dict(self.tenant_char_count) - def remove_tenant_entirely(self, tenant: str) -> int: - """Remove all nodes belonging to a tenant, returns the number of characters removed. Also removes the tenant from tenant_char_count and tenant_nodes.""" + def remove_tenant(self, tenant: str) -> int: + """Remove a tenant's nodes from the tree, returns the number of characters removed. Also removes the tenant from tenants, tenant_char_count, and tenant_nodes.""" with self.lock: + if tenant not in self.tenants: + raise ValueError(f"Cannot remove tenant '{tenant}': tenant does not exist") + total_chars_removed: int = 0 for node in self.tenant_nodes[tenant].copy(): total_chars_removed += self.remove_tenant_single_node(tenant, node) + + self.tenants.remove(tenant) self.tenant_nodes.pop(tenant, None) self.tenant_char_count.pop(tenant, None) return total_chars_removed + def remove_tenant_single_node(self, tenant: str, node: Node) -> int: """Remove a single node belonging to a tenant, returns the number of characters removed.""" with self.lock: + if tenant not in self.tenants: + raise ValueError(f"Cannot remove tenant '{tenant}': tenant does not exist") + if node not in self.tenant_nodes[tenant] or tenant not in node.tenant_last_access_time: + raise ValueError(f"Cannot remove node '{node.text}' from tenant '{tenant}': tenant does not have this node") + removed_chars_len: int = len(node.text) self.tenant_char_count[tenant] -= removed_chars_len self.tenant_nodes[tenant].remove(node) node.tenant_last_access_time.pop(tenant, None) - # If this node has no more tenants, remove it from the parent - if not node.tenant_last_access_time: + if not node.tenant_last_access_time and node.parent: node.parent.children.pop(node.text[0], None) return removed_chars_len - def evict_tenant(self, tenant: str, min_remove_size: int) -> int: + + def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: """Evict nodes from a tenant until the removed character count is at least min_remove_size. Args: @@ -259,7 +336,10 @@ def evict_tenant(self, tenant: str, min_remove_size: int) -> int: """ with self.lock: if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: - return 0 + raise ValueError(f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes") + + if self.tenant_char_count[tenant] < min_remove_size: + raise ValueError(f"Cannot evict tenant '{tenant}': total character count is less than min_remove_size") # Sort nodes by last access time (oldest first) nodes_to_evict = sorted( @@ -267,6 +347,7 @@ def evict_tenant(self, tenant: str, min_remove_size: int) -> int: key=lambda node: node.tenant_last_access_time.get(tenant, 0), ) + total_chars_removed: int = 0 # Remove nodes until we've reached the minimum removal size @@ -279,3 +360,12 @@ def evict_tenant(self, tenant: str, min_remove_size: int) -> int: break return total_chars_removed + + + def get_smallest_tenant(self) -> Optional[str]: + """Get the tenant with the smallest total character count.""" + with self.lock: + if not self.tenant_char_count: + return None + + return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] \ No newline at end of file diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index b8d6eda04a8dc..e5777b73ca418 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -17,131 +17,352 @@ def serve_instance(): @pytest.mark.asyncio -async def test_insert_and_basic_match(): - # Deploy a clean PrefixTree +async def test_add_tenant(): + """Test adding tenants to the tree.""" tree = serve.run(PrefixTree.bind()) - - # Insert and match exact string - await tree.insert.remote("hello", "tenant-A") - matched_text, tenants = await tree.prefix_match.remote("hello") - assert matched_text == "hello" - assert tenants == ["tenant-A"] + + # 1. Test basic tenant addition + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + tree_rep = await tree.to_dict.remote() + assert "tenant_1" in tree_rep["tenants"] + assert tree_rep["tenant_char_count"]["tenant_1"] == 0 + assert tree_rep["tenant_nodes"]["tenant_1"] == set() + + # 2. Test adding duplicate tenant raises ValueError + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + with pytest.raises(ValueError): + await tree.add_tenant.remote("tenant_1") @pytest.mark.asyncio -async def test_tree_splits_nodes_on_partial_match(): +async def test_insert(): + """Test the insert functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - await tree.insert.remote("helloworld", "A") - await tree.insert.remote("hellothere", "B") - - # After inserting both, the root should have one child "h" - root = await tree.get_root.remote() + + # 1. Test basic insertion + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("hello", "tenant_1") + matched_text, tenants = await tree.prefix_match.remote("hello") + assert matched_text == "hello" + assert tenants == ["tenant_1"] + + tree_rep = await tree.to_dict.remote() + assert tree_rep["tenant_char_count"]["tenant_1"] == 5 + assert len(tree_rep["tenant_nodes"]["tenant_1"]) == 2 + + # 2. Test duplicate insertion doesn't double count + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("foo", "tenant_1") + await tree.insert.remote("foo", "tenant_1") # duplicate + await tree.insert.remote("bar", "tenant_2") + + tree_rep = await tree.to_dict.remote() + assert tree_rep["tenant_char_count"]["tenant_1"] == 3 + assert tree_rep["tenant_char_count"]["tenant_2"] == 3 + + # 3. Test node splitting on partial match + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("helloworld", "tenant_1") + await tree.insert.remote("hellothere", "tenant_2") + + tree_rep = await tree.to_dict.remote() + root = tree_rep["root"] h_node = root.children.get("h") assert h_node is not None + assert h_node.text == "hello" + assert h_node.children.get("w").text == "world" + assert h_node.children.get("t").text == "there" + + # 4. Test inserting for non-existent tenant raises ValueError + await tree.reset.remote() + with pytest.raises(ValueError): + await tree.insert.remote("hello", "nonexistent_tenant") @pytest.mark.asyncio -async def test_no_match(): +async def test_prefix_match(): + """Test the prefix_match functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) + + # 1. Test no match + await tree.reset.remote() matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "" assert tenants is None - - -@pytest.mark.asyncio -async def test_duplicate_insertion_no_double_count(): - tree = serve.run(PrefixTree.bind()) - - await tree.insert.remote("foo", "T1") - await tree.insert.remote("foo", "T1") # duplicate - - counts = await tree.get_tenant_char_count.remote() - # Should count 'foo' only once - assert counts.get("T1", 0) == 3 - - -@pytest.mark.asyncio -async def test_shared_prefix_splitting_and_branching(): - tree = serve.run(PrefixTree.bind()) - - await tree.insert.remote("helloworld", "A") - await tree.insert.remote("hellothere", "B") - + + # 2. Test match with non-existing prefix returns empty string and all tenants + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hellothere", "tenant_2") + matched_text, tenants = await tree.prefix_match.remote("foobar") + assert matched_text == "" + assert len(tenants) == 2 + assert "tenant_1" in tenants + assert "tenant_2" in tenants + + # 3. Test exact match + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("hello", "tenant_1") + matched_text, tenants = await tree.prefix_match.remote("hello") + assert matched_text == "hello" + assert tenants == ["tenant_1"] + + + # 4. Test partial match + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("apple", "tenant_1") + await tree.insert.remote("apricot", "tenant_2") + text, tenants = await tree.prefix_match.remote("application") + assert text == "appl" + assert tenants == ["tenant_1"] + + # 5. Test match by tenant + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("apple", "tenant_1") + await tree.insert.remote("apricot", "tenant_2") + text, tenants = await tree.prefix_match.remote("application", ["tenant_2"]) + assert text == "ap" + assert tenants == ["tenant_2"] + + # 6. Test match by non-existent tenant + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("apple", "tenant_1") + await tree.insert.remote("apricot", "tenant_2") + text, tenants = await tree.prefix_match.remote("application", ["tenant_3"]) + assert text == "" + assert tenants is None + + # 7. Test shared prefix matching with branches + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("helloworld", "tenant_1") + await tree.insert.remote("hellothere", "tenant_2") text_a, tenants_a = await tree.prefix_match.remote("helloworld") - text_b, tenants_b = await tree.prefix_match.remote("hellothere") - + text_b, tenants_b = await tree.prefix_match.remote("hellothereworld") assert text_a == "helloworld" - assert tenants_a == ["A"] + assert tenants_a == ["tenant_1"] assert text_b == "hellothere" - assert tenants_b == ["B"] + assert tenants_b == ["tenant_2"] @pytest.mark.asyncio -async def test_prefix_match_partial_and_filter(): +async def test_remove_tenant(): + """Test removing a tenant from the tree.""" tree = serve.run(PrefixTree.bind()) - - await tree.insert.remote("apple", "X") - await tree.insert.remote("apricot", "Y") - - # Partial match for 'application' -> 'appl' - text, tenants = await tree.prefix_match.remote("application") - assert text == "appl" - assert tenants == ["X"] - - # Filter by available_tenants=['X'] on 'apricot' - text_fx, tenants_fx = await tree.prefix_match.remote("apricot", ["X"]) - assert text_fx == "ap" - assert tenants_fx == ["X"] - - # Filter by non-existent tenant yields no tenants - text_fz, tenants_fz = await tree.prefix_match.remote("apricot", ["Z"]) - assert text_fz == "" - assert tenants_fz is None - + + # 1. Test basic tenant removal + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("hello", "tenant_1") + removed = await tree.remove_tenant.remote("tenant_1") + assert removed == 5 + + tree_rep = await tree.to_dict.remote() + assert "tenant_1" not in tree_rep["tenants"] + assert "tenant_1" not in tree_rep["tenant_char_count"] + assert "tenant_1" not in tree_rep["tenant_nodes"] + + # 2. Test removing tenant with multiple nodes + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("cat", "tenant_1") + await tree.insert.remote("dog", "tenant_1") + removed = await tree.remove_tenant.remote("tenant_1") + assert removed == len("cat") + len("dog") + + # 3. Test removing non-existent tenant raises ValueError + await tree.reset.remote() + with pytest.raises(ValueError): + await tree.remove_tenant.remote("nonexistent_tenant") + + # 4. Test tree structure after removing tenant + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hello", "tenant_2") + + # Remove tenant_1, verify tenant_2 still works + await tree.remove_tenant.remote("tenant_1") + + tree_rep = await tree.to_dict.remote() + assert "tenant_1" not in tree_rep["tenants"] + assert "tenant_2" in tree_rep["tenants"] + + matched_text, tenants = await tree.prefix_match.remote("hello") + assert matched_text == "hello" + assert tenants == ["tenant_2"] + + # 5. Test removing the last tenant from a node removes the node + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("unique1", "tenant_1") + await tree.insert.remote("unique2", "tenant_2") + + # Remove tenant_1 + await tree.remove_tenant.remote("tenant_1") + + tree_rep = await tree.to_dict.remote() + root = tree_rep["root"] + # 'u' node should only have one child now ('2' from unique2) + assert 'u' in root.children + assert '2' in root.children['u'].children # '2' from unique2 + assert len(root.children['u'].children) == 1 + @pytest.mark.asyncio -async def test_remove_and_get_smallest_and_evict(): +async def test_remove_tenant_single_node(): + """Test removing a single node for a tenant.""" tree = serve.run(PrefixTree.bind()) + + + # # 1. Test removing a single node + # TEST FAILS: Ray creates new node instances when making remote calls? + # The node from insert.remote() is not identity-equal to the one in tenant_nodes + + # await tree.reset.remote() + # await tree.add_tenant.remote("tenant_1") + # h_node = await tree.insert.remote("hello", "tenant_1") + + # removed = await tree.remove_tenant_single_node.remote("tenant_1", h_node) + # assert removed == 5 + + # tree_rep = await tree.to_dict.remote() + # assert tree_rep["tenant_char_count"]["tenant_1"] == 0 + # assert tree_rep["tenant_nodes"]["tenant_1"] == set() + + # 2. Test removing node for non-existent tenant raises ValueError + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("hello", "tenant_1") + + tree_rep = await tree.to_dict.remote() + root = tree_rep["root"] + h_node = root.children.get("h") + + with pytest.raises(ValueError): + await tree.remove_tenant_single_node.remote("nonexistent_tenant", h_node) + + # 3. Test removing node that doesn't belong to tenant raises ValueError + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("hello", "tenant_1") + + tree_rep = await tree.to_dict.remote() + root = tree_rep["root"] + h_node = root.children.get("h") + + with pytest.raises(ValueError): + await tree.remove_tenant_single_node.remote("tenant_2", h_node) - # Test removal and char count - await tree.insert.remote("cat", "T1") - await tree.insert.remote("dog", "T1") - - counts = await tree.get_tenant_char_count.remote() - assert counts.get("T1") == len("cat") + len("dog") - - # Remove entire tenant - removed = await tree.remove_tenant_entirely.remote("T1") - assert removed == len("cat") + len("dog") - - counts_after = await tree.get_tenant_char_count.remote() - assert "T1" not in counts_after - # Test eviction LRU behavior - await tree.insert.remote("a", "T2") +@pytest.mark.asyncio +async def test_evict_tenant_by_LRU(): + """Test the evict_tenant_by_LRU functionality of PrefixTree.""" + tree = serve.run(PrefixTree.bind()) + + # 1. Test eviction with LRU ordering + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.insert.remote("a", "tenant_1") time.sleep(0.001) - await tree.insert.remote("bb", "T2") + await tree.insert.remote("bb", "tenant_1") time.sleep(0.001) - await tree.insert.remote("ccc", "T2") - - before = (await tree.get_tenant_char_count.remote())["T2"] - evicted = await tree.evict_tenant.remote("T2", 2) - after = (await tree.get_tenant_char_count.remote())["T2"] - - assert evicted >= 2 + await tree.insert.remote("ccc", "tenant_1") + + tree_rep = await tree.to_dict.remote() + before = tree_rep["tenant_char_count"]["tenant_1"] + + evicted = await tree.evict_tenant_by_LRU.remote("tenant_1", 2) + + tree_rep = await tree.to_dict.remote() + after = tree_rep["tenant_char_count"]["tenant_1"] + + assert evicted == 3 assert before - after == evicted + assert "tenant_1" in tree_rep["tenants"] + + # 2. Test eviction of non-existent tenant raises ValueError + await tree.reset.remote() + with pytest.raises(ValueError): + await tree.evict_tenant_by_LRU.remote("nonexistent_tenant", 5) + + # 3. Test eviction of tenant with insufficient characters raises ValueError + await tree.reset.remote() + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("xyz", "tenant_2") + with pytest.raises(ValueError): + await tree.evict_tenant_by_LRU.remote("tenant_2", 4) + + # 4. Test eviction of all tenant data + await tree.reset.remote() + await tree.add_tenant.remote("tenant_2") + await tree.insert.remote("xyz", "tenant_2") + + tree_rep = await tree.to_dict.remote() + total_size = tree_rep["tenant_char_count"]["tenant_2"] + + evicted = await tree.evict_tenant_by_LRU.remote("tenant_2", total_size) + assert evicted == total_size + + tree_rep = await tree.to_dict.remote() + assert "tenant_2" in tree_rep["tenants"] @pytest.mark.asyncio async def test_get_smallest_tenant(): + """Test the get_smallest_tenant functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - await tree.insert.remote("aaaa", "A") - await tree.insert.remote("bb", "B") + + # 1. Test with empty tree + await tree.reset.remote() smallest = await tree.get_smallest_tenant.remote() - assert smallest == "B" + assert smallest is None + + # 2. Test with multiple tenants of different sizes + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.add_tenant.remote("tenant_3") + await tree.insert.remote("aaaa", "tenant_1") + await tree.insert.remote("bb", "tenant_2") + await tree.insert.remote("c", "tenant_3") + + smallest = await tree.get_smallest_tenant.remote() + assert smallest == "tenant_3" + + # 3. Test after removing the smallest tenant + await tree.reset.remote() + await tree.add_tenant.remote("tenant_1") + await tree.add_tenant.remote("tenant_2") + await tree.add_tenant.remote("tenant_3") + await tree.insert.remote("aaaa", "tenant_1") + await tree.insert.remote("bb", "tenant_2") + await tree.insert.remote("c", "tenant_3") + await tree.remove_tenant.remote("tenant_3") + smallest = await tree.get_smallest_tenant.remote() + assert smallest == "tenant_2" if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) From f1f6e95fcdcf14e8f92fe9c5e9ba79fe62b37bbd Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Fri, 2 May 2025 18:09:55 -0700 Subject: [PATCH 05/15] Linting Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 48 ++++---- .../serve/cpu/deployments/test_prefix_tree.py | 103 +++++++++--------- 2 files changed, 79 insertions(+), 72 deletions(-) diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py index a5de6674aacf9..60c76b2f00ab5 100644 --- a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py +++ b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py @@ -1,6 +1,5 @@ from ray import serve import time -from collections import defaultdict from threading import RLock from typing import Optional, List, Tuple, Dict, Set @@ -55,7 +54,7 @@ def to_dict(self) -> Dict: "root": self.root, "tenants": self.tenants, "tenant_char_count": self.tenant_char_count, - "tenant_nodes": self.tenant_nodes + "tenant_nodes": self.tenant_nodes, } def to_string(self) -> str: @@ -121,12 +120,14 @@ def shared_prefix_count(a: str, b: str) -> int: # self.tenant_nodes[tenant].add(curr_node) # return curr_node - + def insert(self, text: str, tenant: str) -> Node: """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated).""" with self.lock: if tenant not in self.tenants: - raise ValueError(f"Cannot insert text for tenant '{tenant}': tenant does not exist") + raise ValueError( + f"Cannot insert text for tenant '{tenant}': tenant does not exist" + ) curr_node: Node = self.root timestamp_ms: int = int(time.time() * 1000) @@ -213,6 +214,7 @@ def insert(self, text: str, tenant: str) -> Node: curr_node.tenant_last_access_time[tenant] = timestamp_ms self.tenant_nodes[tenant].add(curr_node) return curr_node + def prefix_match( self, text: str, available_tenants: Optional[List[str]] = None ) -> Tuple[str, Optional[List[str]]]: @@ -224,8 +226,7 @@ def prefix_match( if available_tenants: # Filter available_tenants to only include those that exist in the tree available_tenants = [ - tenant for tenant in available_tenants - if tenant in self.tenants + tenant for tenant in available_tenants if tenant in self.tenants ] if not available_tenants: return "", None @@ -277,7 +278,6 @@ def prefix_match( ret_text: str = text[:i] return ret_text, selected_tenants - def add_tenant(self, tenant: str) -> None: """Add a tenant to the tree.""" with self.lock: @@ -288,30 +288,37 @@ def add_tenant(self, tenant: str) -> None: self.tenant_char_count[tenant] = 0 self.tenant_nodes[tenant] = set() - def remove_tenant(self, tenant: str) -> int: """Remove a tenant's nodes from the tree, returns the number of characters removed. Also removes the tenant from tenants, tenant_char_count, and tenant_nodes.""" with self.lock: if tenant not in self.tenants: - raise ValueError(f"Cannot remove tenant '{tenant}': tenant does not exist") + raise ValueError( + f"Cannot remove tenant '{tenant}': tenant does not exist" + ) total_chars_removed: int = 0 for node in self.tenant_nodes[tenant].copy(): total_chars_removed += self.remove_tenant_single_node(tenant, node) - + self.tenants.remove(tenant) self.tenant_nodes.pop(tenant, None) self.tenant_char_count.pop(tenant, None) return total_chars_removed - def remove_tenant_single_node(self, tenant: str, node: Node) -> int: """Remove a single node belonging to a tenant, returns the number of characters removed.""" with self.lock: if tenant not in self.tenants: - raise ValueError(f"Cannot remove tenant '{tenant}': tenant does not exist") - if node not in self.tenant_nodes[tenant] or tenant not in node.tenant_last_access_time: - raise ValueError(f"Cannot remove node '{node.text}' from tenant '{tenant}': tenant does not have this node") + raise ValueError( + f"Cannot remove tenant '{tenant}': tenant does not exist" + ) + if ( + node not in self.tenant_nodes[tenant] + or tenant not in node.tenant_last_access_time + ): + raise ValueError( + f"Cannot remove node '{node.text}' from tenant '{tenant}': tenant does not have this node" + ) removed_chars_len: int = len(node.text) self.tenant_char_count[tenant] -= removed_chars_len @@ -323,7 +330,6 @@ def remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len - def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: """Evict nodes from a tenant until the removed character count is at least min_remove_size. @@ -336,10 +342,14 @@ def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: """ with self.lock: if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: - raise ValueError(f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes") + raise ValueError( + f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes" + ) if self.tenant_char_count[tenant] < min_remove_size: - raise ValueError(f"Cannot evict tenant '{tenant}': total character count is less than min_remove_size") + raise ValueError( + f"Cannot evict tenant '{tenant}': total character count is less than min_remove_size" + ) # Sort nodes by last access time (oldest first) nodes_to_evict = sorted( @@ -347,7 +357,6 @@ def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: key=lambda node: node.tenant_last_access_time.get(tenant, 0), ) - total_chars_removed: int = 0 # Remove nodes until we've reached the minimum removal size @@ -361,11 +370,10 @@ def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: return total_chars_removed - def get_smallest_tenant(self) -> Optional[str]: """Get the tenant with the smallest total character count.""" with self.lock: if not self.tenant_char_count: return None - return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] \ No newline at end of file + return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index e5777b73ca418..f66cd46418622 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -20,7 +20,7 @@ def serve_instance(): async def test_add_tenant(): """Test adding tenants to the tree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test basic tenant addition await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -28,7 +28,7 @@ async def test_add_tenant(): assert "tenant_1" in tree_rep["tenants"] assert tree_rep["tenant_char_count"]["tenant_1"] == 0 assert tree_rep["tenant_nodes"]["tenant_1"] == set() - + # 2. Test adding duplicate tenant raises ValueError await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -40,7 +40,7 @@ async def test_add_tenant(): async def test_insert(): """Test the insert functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test basic insertion await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -48,11 +48,11 @@ async def test_insert(): matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] - + tree_rep = await tree.to_dict.remote() assert tree_rep["tenant_char_count"]["tenant_1"] == 5 assert len(tree_rep["tenant_nodes"]["tenant_1"]) == 2 - + # 2. Test duplicate insertion doesn't double count await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -60,7 +60,7 @@ async def test_insert(): await tree.insert.remote("foo", "tenant_1") await tree.insert.remote("foo", "tenant_1") # duplicate await tree.insert.remote("bar", "tenant_2") - + tree_rep = await tree.to_dict.remote() assert tree_rep["tenant_char_count"]["tenant_1"] == 3 assert tree_rep["tenant_char_count"]["tenant_2"] == 3 @@ -71,7 +71,7 @@ async def test_insert(): await tree.add_tenant.remote("tenant_2") await tree.insert.remote("helloworld", "tenant_1") await tree.insert.remote("hellothere", "tenant_2") - + tree_rep = await tree.to_dict.remote() root = tree_rep["root"] h_node = root.children.get("h") @@ -79,7 +79,7 @@ async def test_insert(): assert h_node.text == "hello" assert h_node.children.get("w").text == "world" assert h_node.children.get("t").text == "there" - + # 4. Test inserting for non-existent tenant raises ValueError await tree.reset.remote() with pytest.raises(ValueError): @@ -90,13 +90,13 @@ async def test_insert(): async def test_prefix_match(): """Test the prefix_match functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test no match await tree.reset.remote() matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "" assert tenants is None - + # 2. Test match with non-existing prefix returns empty string and all tenants await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -108,7 +108,7 @@ async def test_prefix_match(): assert len(tenants) == 2 assert "tenant_1" in tenants assert "tenant_2" in tenants - + # 3. Test exact match await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -116,8 +116,7 @@ async def test_prefix_match(): matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] - - + # 4. Test partial match await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -127,7 +126,7 @@ async def test_prefix_match(): text, tenants = await tree.prefix_match.remote("application") assert text == "appl" assert tenants == ["tenant_1"] - + # 5. Test match by tenant await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -137,7 +136,7 @@ async def test_prefix_match(): text, tenants = await tree.prefix_match.remote("application", ["tenant_2"]) assert text == "ap" assert tenants == ["tenant_2"] - + # 6. Test match by non-existent tenant await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -147,7 +146,7 @@ async def test_prefix_match(): text, tenants = await tree.prefix_match.remote("application", ["tenant_3"]) assert text == "" assert tenants is None - + # 7. Test shared prefix matching with branches await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -166,19 +165,19 @@ async def test_prefix_match(): async def test_remove_tenant(): """Test removing a tenant from the tree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test basic tenant removal await tree.reset.remote() await tree.add_tenant.remote("tenant_1") await tree.insert.remote("hello", "tenant_1") removed = await tree.remove_tenant.remote("tenant_1") assert removed == 5 - + tree_rep = await tree.to_dict.remote() assert "tenant_1" not in tree_rep["tenants"] assert "tenant_1" not in tree_rep["tenant_char_count"] assert "tenant_1" not in tree_rep["tenant_nodes"] - + # 2. Test removing tenant with multiple nodes await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -186,12 +185,12 @@ async def test_remove_tenant(): await tree.insert.remote("dog", "tenant_1") removed = await tree.remove_tenant.remote("tenant_1") assert removed == len("cat") + len("dog") - + # 3. Test removing non-existent tenant raises ValueError await tree.reset.remote() with pytest.raises(ValueError): await tree.remove_tenant.remote("nonexistent_tenant") - + # 4. Test tree structure after removing tenant await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -201,38 +200,37 @@ async def test_remove_tenant(): # Remove tenant_1, verify tenant_2 still works await tree.remove_tenant.remote("tenant_1") - + tree_rep = await tree.to_dict.remote() assert "tenant_1" not in tree_rep["tenants"] assert "tenant_2" in tree_rep["tenants"] - + matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "hello" assert tenants == ["tenant_2"] - + # 5. Test removing the last tenant from a node removes the node await tree.reset.remote() await tree.add_tenant.remote("tenant_1") await tree.add_tenant.remote("tenant_2") await tree.insert.remote("unique1", "tenant_1") await tree.insert.remote("unique2", "tenant_2") - + # Remove tenant_1 await tree.remove_tenant.remote("tenant_1") - + tree_rep = await tree.to_dict.remote() root = tree_rep["root"] # 'u' node should only have one child now ('2' from unique2) - assert 'u' in root.children - assert '2' in root.children['u'].children # '2' from unique2 - assert len(root.children['u'].children) == 1 - + assert "u" in root.children + assert "2" in root.children["u"].children # '2' from unique2 + assert len(root.children["u"].children) == 1 + @pytest.mark.asyncio async def test_remove_tenant_single_node(): """Test removing a single node for a tenant.""" tree = serve.run(PrefixTree.bind()) - # # 1. Test removing a single node # TEST FAILS: Ray creates new node instances when making remote calls? @@ -241,36 +239,36 @@ async def test_remove_tenant_single_node(): # await tree.reset.remote() # await tree.add_tenant.remote("tenant_1") # h_node = await tree.insert.remote("hello", "tenant_1") - + # removed = await tree.remove_tenant_single_node.remote("tenant_1", h_node) # assert removed == 5 - + # tree_rep = await tree.to_dict.remote() # assert tree_rep["tenant_char_count"]["tenant_1"] == 0 # assert tree_rep["tenant_nodes"]["tenant_1"] == set() - + # 2. Test removing node for non-existent tenant raises ValueError await tree.reset.remote() await tree.add_tenant.remote("tenant_1") await tree.insert.remote("hello", "tenant_1") - + tree_rep = await tree.to_dict.remote() root = tree_rep["root"] h_node = root.children.get("h") - + with pytest.raises(ValueError): await tree.remove_tenant_single_node.remote("nonexistent_tenant", h_node) - + # 3. Test removing node that doesn't belong to tenant raises ValueError await tree.reset.remote() await tree.add_tenant.remote("tenant_1") await tree.add_tenant.remote("tenant_2") await tree.insert.remote("hello", "tenant_1") - + tree_rep = await tree.to_dict.remote() root = tree_rep["root"] h_node = root.children.get("h") - + with pytest.raises(ValueError): await tree.remove_tenant_single_node.remote("tenant_2", h_node) @@ -279,7 +277,7 @@ async def test_remove_tenant_single_node(): async def test_evict_tenant_by_LRU(): """Test the evict_tenant_by_LRU functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test eviction with LRU ordering await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -288,19 +286,19 @@ async def test_evict_tenant_by_LRU(): await tree.insert.remote("bb", "tenant_1") time.sleep(0.001) await tree.insert.remote("ccc", "tenant_1") - + tree_rep = await tree.to_dict.remote() before = tree_rep["tenant_char_count"]["tenant_1"] - + evicted = await tree.evict_tenant_by_LRU.remote("tenant_1", 2) - + tree_rep = await tree.to_dict.remote() after = tree_rep["tenant_char_count"]["tenant_1"] - + assert evicted == 3 assert before - after == evicted assert "tenant_1" in tree_rep["tenants"] - + # 2. Test eviction of non-existent tenant raises ValueError await tree.reset.remote() with pytest.raises(ValueError): @@ -317,13 +315,13 @@ async def test_evict_tenant_by_LRU(): await tree.reset.remote() await tree.add_tenant.remote("tenant_2") await tree.insert.remote("xyz", "tenant_2") - + tree_rep = await tree.to_dict.remote() total_size = tree_rep["tenant_char_count"]["tenant_2"] - + evicted = await tree.evict_tenant_by_LRU.remote("tenant_2", total_size) assert evicted == total_size - + tree_rep = await tree.to_dict.remote() assert "tenant_2" in tree_rep["tenants"] @@ -332,12 +330,12 @@ async def test_evict_tenant_by_LRU(): async def test_get_smallest_tenant(): """Test the get_smallest_tenant functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - + # 1. Test with empty tree await tree.reset.remote() smallest = await tree.get_smallest_tenant.remote() assert smallest is None - + # 2. Test with multiple tenants of different sizes await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -346,10 +344,10 @@ async def test_get_smallest_tenant(): await tree.insert.remote("aaaa", "tenant_1") await tree.insert.remote("bb", "tenant_2") await tree.insert.remote("c", "tenant_3") - + smallest = await tree.get_smallest_tenant.remote() assert smallest == "tenant_3" - + # 3. Test after removing the smallest tenant await tree.reset.remote() await tree.add_tenant.remote("tenant_1") @@ -365,4 +363,5 @@ async def test_get_smallest_tenant(): if __name__ == "__main__": import sys + sys.exit(pytest.main(["-v", __file__])) From ee9568df4a8942ab505beebab19a81a52de38e55 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Mon, 5 May 2025 18:10:21 -0700 Subject: [PATCH 06/15] Address comments Signed-off-by: Justin Ji --- .../serve/deployments/routers/prefix_tree.py | 379 -------------- .../prefix_aware/prefix_tree.py | 486 ++++++++++++++++++ .../serve/cpu/deployments/test_prefix_tree.py | 168 +++--- 3 files changed, 562 insertions(+), 471 deletions(-) delete mode 100644 python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py create mode 100644 python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py diff --git a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py b/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py deleted file mode 100644 index 60c76b2f00ab5..0000000000000 --- a/python/ray/llm/_internal/serve/deployments/routers/prefix_tree.py +++ /dev/null @@ -1,379 +0,0 @@ -from ray import serve -import time -from threading import RLock -from typing import Optional, List, Tuple, Dict, Set - - -class Node: - """ - Node in a prefix tree that tracks tenant access time. - - Each node represents a segment of text and can belong to multiple tenants. - """ - - def __init__(self, text: str = "", parent: Optional["Node"] = None): - self.text: str = text - self.parent: Optional["Node"] = parent - self.children: Dict[str, "Node"] = {} # Maps char -> Node - self.tenant_last_access_time: Dict[ - str, int - ] = {} # Maps tenant -> timestamp in ms (int) - - def to_string(self) -> str: - return f"Node(text='{self.text}', parent={self.parent}, children={self.children}, tenant_last_access_time={self.tenant_last_access_time})" - - -@serve.deployment(name="TreeDeployment") -class PrefixTree: - """ - Thread-safe multi-tenant prefix tree (approximate radix tree). - - Features: - 1. Stores data for multiple tenants in the same tree structure - 2. Node-level locking for concurrent access - 3. Leaf LRU eviction based on tenant access time - """ - - def __init__(self) -> None: - self.lock: RLock = RLock() - self.root: Node = Node() - self.tenants: Set[str] = set() - self.tenant_char_count: Dict[str, int] = {} - self.tenant_nodes: Dict[str, Set[Node]] = {} - - def reset(self) -> None: - """Reset the tree to an empty state.""" - with self.lock: - self.root = Node() - self.tenants = set() - self.tenant_char_count = {} - self.tenant_nodes = {} - - def to_dict(self) -> Dict: - return { - "root": self.root, - "tenants": self.tenants, - "tenant_char_count": self.tenant_char_count, - "tenant_nodes": self.tenant_nodes, - } - - def to_string(self) -> str: - return f"PrefixTree(root={self.root.__str__()}, tenants={self.tenants}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes})" - - @staticmethod - def shared_prefix_count(a: str, b: str) -> int: - """Count the number of shared characters at the beginning of two strings.""" - i: int = 0 - for char_a, char_b in zip(a, b): - if char_a == char_b: - i += 1 - else: - break - return i - - # def insert(self, text: str, tenant: str) -> Node: - # """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated).""" - # with self.lock: - # if tenant not in self.tenants: - # raise ValueError(f"Cannot insert text for tenant '{tenant}': tenant does not exist") - - # curr_node: Node = self.root - # timestamp_ms: int = int(time.time() * 1000) - # i: int = 0 - # while i < len(text): - # self.tenant_nodes[tenant].add(curr_node) - - # first_char: str = text[i] - # curr_text: str = text[i:] - # if first_char not in curr_node.children: - # # No match, create new node - # # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} - # new_node: Node = Node(text=curr_text, parent=curr_node) - # new_node.tenant_last_access_time[tenant] = timestamp_ms - # curr_node.children[first_char] = new_node - - # # Match found, check if need to split - # matched_node: Node = curr_node.children[first_char] - # shared_count: int = self.shared_prefix_count( - # matched_node.text, curr_text - # ) - # if shared_count == len(matched_node.text): - # # Full match, move down the tree - # if tenant not in matched_node.tenant_last_access_time: - # self.tenant_char_count[tenant] += shared_count - # matched_node.tenant_last_access_time[tenant] = timestamp_ms - # curr_node = matched_node - # else: - # # Partial match, split at matched point - # matched_text: str = matched_node.text[:shared_count] - # remaining_text: str = matched_node.text[shared_count:] - # new_parent: Node = Node(text=matched_text, parent=curr_node) - # matched_node.text = remaining_text - # matched_node.parent = new_parent - # new_parent.children[remaining_text[0]] = matched_node - # if tenant not in new_parent.tenant_last_access_time: - # self.tenant_char_count[tenant] += shared_count - # new_parent.tenant_last_access_time[tenant] = timestamp_ms - # curr_node = new_parent - - # i += shared_count - - # self.tenant_nodes[tenant].add(curr_node) - # return curr_node - - def insert(self, text: str, tenant: str) -> Node: - """Insert text into tree with given tenant. Returns the node that was inserted (or the existing node if it was updated).""" - with self.lock: - if tenant not in self.tenants: - raise ValueError( - f"Cannot insert text for tenant '{tenant}': tenant does not exist" - ) - - curr_node: Node = self.root - timestamp_ms: int = int(time.time() * 1000) - i: int = 0 - while i < len(text): - curr_node.tenant_last_access_time[tenant] = timestamp_ms - self.tenant_nodes[tenant].add(curr_node) - - first_char: str = text[i] - curr_text: str = text[i:] - if first_char not in curr_node.children: - # No match, create new node - # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} - new_node: Node = Node(text=curr_text, parent=curr_node) - new_node.tenant_last_access_time[tenant] = timestamp_ms - - # Increment char count for tenant and add node to tenant_nodes - self.tenant_char_count[tenant] += len(curr_text) - self.tenant_nodes[tenant].add(new_node) - - curr_node.children[first_char] = new_node - else: - # Match found, check if need to split - matched_node: Node = curr_node.children[first_char] - shared_count: int = self.shared_prefix_count( - matched_node.text, curr_text - ) - - if shared_count < len(matched_node.text): - # Partial match, split at matched point - # Example: - ## Before update: - ### curr_node.children = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 - ### matched_node = Node("helloworld") - - ## During update: - ### Increment tenant_char_count[tenant] by shared_count if matched_node has not seen this tenant before - - ## After update: - ### curr_node.children = {"h": Node("hello", children = {"w": Node("world")})} - ### parent_node = Node("hello"), matched_node = Node("world") - ### Update tenant_last_access_time for parent_node, NOT matched_node - ### (new) curr_text = "there", (new) curr_node = parent_node - ### Continue adding "there" to tree in next iteration - - matched_text: str = matched_node.text[:shared_count] - remaining_text: str = matched_node.text[shared_count:] - - # Update tenant char count for the new split node - if tenant not in matched_node.tenant_last_access_time: - self.tenant_char_count[tenant] += shared_count - - # Create new parent node - new_parent: Node = Node(text=matched_text, parent=curr_node) - new_parent.tenant_last_access_time = ( - matched_node.tenant_last_access_time.copy() - ) - # Update matched_node - matched_node.text = remaining_text - matched_node.parent = new_parent - - # Connect new parent node to matched_node - new_parent.children[remaining_text[0]] = matched_node - - # Connect current node to new parent - curr_node.children[first_char] = new_parent - - # Move down the tree - curr_node = new_parent - i += shared_count - else: - # Full match - - # Update tenant char count if this is a new tenant for this node - if tenant not in matched_node.tenant_last_access_time: - self.tenant_char_count[tenant] += shared_count - - # # Update tenant last access time - # matched_node.tenant_last_access_time[tenant] = timestamp_ms - - # Move down the tree - curr_node = matched_node - i += shared_count - curr_node.tenant_last_access_time[tenant] = timestamp_ms - self.tenant_nodes[tenant].add(curr_node) - return curr_node - - def prefix_match( - self, text: str, available_tenants: Optional[List[str]] = None - ) -> Tuple[str, Optional[List[str]]]: - """ - Match text against tree and return (matched_text, matched_tenants). - Does not update access time for the matched tenants (only updates when insert() is called). - If available_tenants is not provided, all tenants are considered. - """ - if available_tenants: - # Filter available_tenants to only include those that exist in the tree - available_tenants = [ - tenant for tenant in available_tenants if tenant in self.tenants - ] - if not available_tenants: - return "", None - else: - available_tenants = list(self.tenants) - - with self.lock: - curr_node: Node = self.root - i: int = 0 - text_len: int = len(text) - - while i < text_len: - first_char: str = text[i] - curr_text: str = text[i:] - - if first_char in curr_node.children: - matched_node: Node = curr_node.children[first_char] - - # Check if any of the available tenants match this node - if not any( - tenant in matched_node.tenant_last_access_time - for tenant in available_tenants - ): - break - - shared_count: int = self.shared_prefix_count( - matched_node.text, curr_text - ) - i += shared_count - curr_node = matched_node - - if shared_count < len(matched_node.text): - # Partial match, stop here - break - else: - # No match found, stop here - break - - # Select the tenants in available_tenants that are in the current node - selected_tenants: Optional[List[str]] = None - matching_tenants = [ - tenant - for tenant in available_tenants - if tenant in curr_node.tenant_last_access_time - ] - if matching_tenants: - selected_tenants = matching_tenants - - ret_text: str = text[:i] - return ret_text, selected_tenants - - def add_tenant(self, tenant: str) -> None: - """Add a tenant to the tree.""" - with self.lock: - if tenant in self.tenants: - raise ValueError(f"Cannot add tenant '{tenant}': tenant already exists") - - self.tenants.add(tenant) - self.tenant_char_count[tenant] = 0 - self.tenant_nodes[tenant] = set() - - def remove_tenant(self, tenant: str) -> int: - """Remove a tenant's nodes from the tree, returns the number of characters removed. Also removes the tenant from tenants, tenant_char_count, and tenant_nodes.""" - with self.lock: - if tenant not in self.tenants: - raise ValueError( - f"Cannot remove tenant '{tenant}': tenant does not exist" - ) - - total_chars_removed: int = 0 - for node in self.tenant_nodes[tenant].copy(): - total_chars_removed += self.remove_tenant_single_node(tenant, node) - - self.tenants.remove(tenant) - self.tenant_nodes.pop(tenant, None) - self.tenant_char_count.pop(tenant, None) - return total_chars_removed - - def remove_tenant_single_node(self, tenant: str, node: Node) -> int: - """Remove a single node belonging to a tenant, returns the number of characters removed.""" - with self.lock: - if tenant not in self.tenants: - raise ValueError( - f"Cannot remove tenant '{tenant}': tenant does not exist" - ) - if ( - node not in self.tenant_nodes[tenant] - or tenant not in node.tenant_last_access_time - ): - raise ValueError( - f"Cannot remove node '{node.text}' from tenant '{tenant}': tenant does not have this node" - ) - - removed_chars_len: int = len(node.text) - self.tenant_char_count[tenant] -= removed_chars_len - self.tenant_nodes[tenant].remove(node) - node.tenant_last_access_time.pop(tenant, None) - # If this node has no more tenants, remove it from the parent - if not node.tenant_last_access_time and node.parent: - node.parent.children.pop(node.text[0], None) - - return removed_chars_len - - def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: - """Evict nodes from a tenant until the removed character count is at least min_remove_size. - - Args: - tenant: The tenant to evict nodes from - min_remove_size: Minimum number of characters to remove - - Returns: - int: The actual number of characters removed - """ - with self.lock: - if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: - raise ValueError( - f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes" - ) - - if self.tenant_char_count[tenant] < min_remove_size: - raise ValueError( - f"Cannot evict tenant '{tenant}': total character count is less than min_remove_size" - ) - - # Sort nodes by last access time (oldest first) - nodes_to_evict = sorted( - self.tenant_nodes[tenant], - key=lambda node: node.tenant_last_access_time.get(tenant, 0), - ) - - total_chars_removed: int = 0 - - # Remove nodes until we've reached the minimum removal size - for node in nodes_to_evict.copy(): - # Use existing function to remove tenant from node - total_chars_removed += self.remove_tenant_single_node(tenant, node) - - # Check if we've removed enough characters - if total_chars_removed >= min_remove_size: - break - - return total_chars_removed - - def get_smallest_tenant(self) -> Optional[str]: - """Get the tenant with the smallest total character count.""" - with self.lock: - if not self.tenant_char_count: - return None - - return min(self.tenant_char_count.items(), key=lambda x: x[1])[0] diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py new file mode 100644 index 0000000000000..90fcb765571e1 --- /dev/null +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -0,0 +1,486 @@ +from __future__ import annotations + +import heapq +import logging +import os +from threading import RLock +from typing import Dict, List, Optional, Set, Tuple + +from ray import serve + +# Logger for this module +logger = logging.getLogger(__name__) + + +class Node: + """ + Node in a prefix tree that tracks tenant access time. + + Each node represents a segment of text and can belong to multiple tenants. + + Example tree structure: + Representing the strings inserted in order: + - "helloworld" at time 1 by tenant_1 + - "hellothere" at time 2 by tenant_2 + - "hellothomas" at time 3 by tenant_2 + + root: [] + {tenant_1: 1, tenant_2: 3} + | + (h)| + | + [hello] + {tenant_1: 1, tenant_2: 3} + / \ + (w)/ \\(t) + / \ + [world] [th] + {tenant_1: 1} {tenant_2: 3} + / \ + (e)/ \\(o) + / \ + [ere] [omas] + {tenant_2: 2} {tenant_2: 3} + + Legend for each node: + - [text] = Node.text + - {tenant, timestamp} = Node.tenant_last_access_time + - (x) = edge label (first character used as key for parent's children) + """ + + def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: + """ + Initialize a node in the prefix tree. + + Args: + text: The text segment this node represents + parent: The parent node of this node + """ + self.text: str = text + self.parent: Optional[Node] = parent + self.children: Dict[str, Node] = {} # Maps first character to child node + self.tenant_last_access_time: Dict[str, int] = ( + {} + ) # Maps tenant ID to last access timestamp (in milliseconds) + + def __repr__(self) -> str: + return f"Node(text='{self.text}', children={list(self.children.keys())}, tenants={list(self.tenant_last_access_time.keys())})" + + +class TenantHeapNode: + """ + Wrapper class for storing nodes in a min-heap, ordered by tenant access time. + Used for efficient LRU eviction of tenant nodes. + """ + + def __init__(self, node: Node, tenant: str) -> None: + """ + Initialize a heap node for efficient LRU tenant management. + + Args: + node: The prefix tree node this heap node refers to + tenant: The tenant ID this heap node is associated with + """ + self.node = node + self.tenant_ordering_key = tenant + + def __lt__(self, other: TenantHeapNode) -> bool: + """ + Compare heap nodes based on tenant's last access time. + + Args: + other: Another TenantHeapNode to compare with + + Returns: + True if this node's tenant access time is earlier than the other's + """ + return ( + self.node.tenant_last_access_time[self.tenant_ordering_key] + < other.node.tenant_last_access_time[other.tenant_ordering_key] + ) + + def __repr__(self) -> str: + return f"TenantHeapNode(node={self.node}, tenant_ordering_key={self.tenant_ordering_key})" + + +@serve.deployment(name="TreeDeployment") +class PrefixTree: + """ + Thread-safe multi-tenant prefix tree (approximate radix tree). + + Features: + 1. Stores data for multiple tenants in the same tree structure + 2. Thread-safe with node-level locking for concurrent access + 3. LRU eviction based on tenant access time + 4. Efficient prefix matching across multiple tenants + """ + + def __init__(self) -> None: + """Initialize an empty prefix tree.""" + self.lock: RLock = RLock() + self.root: Node = Node() + self.tenants: Set[str] = set() # Set of tenant IDs + self.tenant_char_count: Dict[str, int] = ( + {} + ) # Tracks total character count per tenant + self.tenant_nodes: Dict[str, Set[Node]] = ( + {} + ) # Maps tenant ID to set of nodes belonging to that tenant + self.tenant_nodes_sorted: Dict[str, List[TenantHeapNode]] = ( + {} + ) # Maps tenant ID to heap of nodes for LRU eviction + + def reset(self) -> None: + """Reset the tree to an empty state.""" + with self.lock: + self.root = Node() + self.tenants = set() + self.tenant_char_count = {} + self.tenant_nodes = {} + self.tenant_nodes_sorted = {} + + def to_dict(self) -> Dict: + """ + Convert tree to dictionary for serialization. + + Returns: + Dictionary representation of the tree + """ + return { + "root": self.root, + "tenants": self.tenants, + "tenant_char_count": self.tenant_char_count, + "tenant_nodes": self.tenant_nodes, + "tenant_nodes_sorted": self.tenant_nodes_sorted, + } + + def to_string(self) -> str: + """String representation of the tree.""" + return f"PrefixTree(tenants={self.tenants}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes}, tenant_nodes_sorted={self.tenant_nodes_sorted})" + + @staticmethod + def _shared_prefix_count(a: str, b: str) -> int: + """ + Count the number of shared characters at the beginning of two strings. + + Args: + a: First string + b: Second string + + Returns: + Number of matching characters at the beginning + """ + return len(os.path.commonprefix([a, b])) + + def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: + """ + Insert text into tree for a specific tenant. + + If the tenant doesn't exist, it will be automatically added. + + Args: + text: Text to insert + tenant: Tenant ID + timestamp_ms: Current timestamp in milliseconds + + Returns: + The node that was inserted or updated + """ + with self.lock: + if tenant not in self.tenants: + self._add_tenant(tenant) + + curr_node: Node = self.root + i: int = 0 + + while i <= len(text): + # Invariant: assume curr_node has not been visited by tenant yet + # Update tenant info for current node + if tenant not in curr_node.tenant_last_access_time: + self.tenant_nodes[tenant].add(curr_node) + self.tenant_char_count[tenant] += len(curr_node.text) + self.tenant_nodes_sorted[tenant].append( + TenantHeapNode(curr_node, tenant) + ) + + curr_node.tenant_last_access_time[tenant] = timestamp_ms + heapq.heapify(self.tenant_nodes_sorted[tenant]) + + if i == len(text): + break + + first_char: str = text[i] + curr_text: str = text[i:] + + if first_char not in curr_node.children: + # No match, create new node. Don't update new node as "visited" by tenant yet; it will be done in the code below. + # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} + new_node: Node = Node(text=curr_text, parent=curr_node) + curr_node.children[first_char] = new_node + + # Match found, check if we need to split + matched_node: Node = curr_node.children[first_char] + shared_count: int = self._shared_prefix_count( + matched_node.text, curr_text + ) + + if shared_count < len(matched_node.text): + # Partial match, split node at matched point + # Example: + ## Before update: + ### curr_node.children = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 + ### matched_node = Node("helloworld") + + ## During update: + ### Increment tenant_char_count[tenant] by shared_count if matched_node has not seen this tenant before + + ## After update: + ### curr_node.children = {"h": Node("hello", children = {"w": Node("world")})} + ### parent_node = Node("hello"), matched_node = Node("world") + ### Update tenant_last_access_time for parent_node, NOT matched_node + ### (new) curr_text = "there", (new) curr_node = parent_node + ### Continue adding "there" to tree in next iteration + + matched_text: str = matched_node.text[:shared_count] + remaining_text: str = matched_node.text[shared_count:] + + # Create new intermediate node + new_parent: Node = Node(text=matched_text, parent=curr_node) + new_parent.tenant_last_access_time = ( + matched_node.tenant_last_access_time.copy() + ) + + # Update existing matched node + matched_node.text = remaining_text + matched_node.parent = new_parent + + # Connect nodes + new_parent.children[remaining_text[0]] = matched_node + curr_node.children[first_char] = new_parent + + # Continue traversal + curr_node = new_parent + i += shared_count + else: + # Full match, continue down the tree + curr_node = matched_node + i += shared_count + + return curr_node + + def prefix_match( + self, text: str, available_tenants: Optional[List[str]] = None + ) -> Tuple[str, Optional[List[str]]]: + """ + Match text against tree and return matched text and matching tenants. + + Args: + text: Text to match + available_tenants: List of tenants to match against (or None for all) + + Returns: + Tuple of (matched_text, matched_tenant_ids) + """ + if available_tenants: + # Filter available_tenants to only include those in the tree + available_tenants = [ + tenant for tenant in available_tenants if tenant in self.tenants + ] + if not available_tenants: + return "", None + else: + available_tenants = list(self.tenants) + + with self.lock: + curr_node: Node = self.root + i: int = 0 + text_len: int = len(text) + + while i < text_len: + first_char: str = text[i] + curr_text: str = text[i:] + + if first_char in curr_node.children: + matched_node: Node = curr_node.children[first_char] + + # Check if any available tenants match this node + if not any( + tenant in matched_node.tenant_last_access_time + for tenant in available_tenants + ): + break + + shared_count: int = self._shared_prefix_count( + matched_node.text, curr_text + ) + i += shared_count + curr_node = matched_node + + if shared_count < len(matched_node.text): + # Partial match, stop here + break + else: + # No match found, stop here + break + + # Find tenants in current node that match available tenants + matching_tenants = [ + tenant + for tenant in available_tenants + if tenant in curr_node.tenant_last_access_time + ] + + selected_tenants: Optional[List[str]] = ( + matching_tenants if matching_tenants else None + ) + matched_text: str = text[:i] + + return matched_text, selected_tenants + + def remove_tenant(self, tenant: str) -> int: + """ + Remove a tenant and all its nodes from the tree. + + Args: + tenant: Tenant ID to remove + + Returns: + Number of characters removed + + Raises: + ValueError: If tenant does not exist + """ + with self.lock: + if tenant not in self.tenants: + raise ValueError( + f"Cannot remove tenant '{tenant}': tenant does not exist" + ) + + total_chars_removed: int = 0 + for node in self.tenant_nodes[tenant].copy(): + total_chars_removed += self._remove_tenant_single_node(tenant, node) + + self.tenants.remove(tenant) + self.tenant_nodes.pop(tenant, None) + self.tenant_char_count.pop(tenant, None) + self.tenant_nodes_sorted.pop(tenant, None) + + return total_chars_removed + + def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: + """ + Remove a tenant from a single node. + + Args: + tenant: Tenant ID to remove + node: Node to remove tenant from + + Returns: + Number of characters removed + + Raises: + ValueError: If tenant does not exist or node doesn't belong to tenant + """ + with self.lock: + if tenant not in self.tenants: + raise ValueError( + f"Cannot remove tenant '{tenant}': tenant does not exist" + ) + + if ( + node not in self.tenant_nodes[tenant] + or tenant not in node.tenant_last_access_time + ): + raise ValueError( + f"Cannot remove node '{node.text}' from tenant '{tenant}': " + f"tenant does not have this node" + ) + + removed_chars_len: int = len(node.text) + self.tenant_char_count[tenant] -= removed_chars_len + self.tenant_nodes[tenant].remove(node) + node.tenant_last_access_time.pop(tenant, None) + + # Clean up empty nodes + if not node.tenant_last_access_time and node.parent: + if ( + node.text and node.text[0] in node.parent.children + ): # Defensive check + node.parent.children.pop(node.text[0], None) + + return removed_chars_len + + def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: + """ + Evict least recently used nodes for a tenant until minimum size is freed. + + Args: + tenant: The tenant to evict nodes from + min_remove_size: Minimum number of characters to remove + + Returns: + Actual number of characters removed + + Raises: + ValueError: If tenant doesn't exist or has insufficient nodes + """ + with self.lock: + if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: + raise ValueError( + f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes" + ) + + if self.tenant_char_count[tenant] < min_remove_size: + raise ValueError( + f"Cannot evict tenant '{tenant}': total character count " + f"({self.tenant_char_count[tenant]}) is less than min_remove_size " + f"({min_remove_size})" + ) + + total_chars_removed: int = 0 + + # Directly use the tenant's priority queue + while ( + total_chars_removed < min_remove_size + and self.tenant_nodes_sorted[tenant] + ): + heap_node: TenantHeapNode = heapq.heappop( + self.tenant_nodes_sorted[tenant] + ) + total_chars_removed += self._remove_tenant_single_node( + tenant, heap_node.node + ) + + return total_chars_removed + + def get_smallest_tenant(self) -> Optional[str]: + """ + Get the tenant with the smallest total character count. + + Returns: + Tenant ID with smallest character count, or None if no tenants + """ + with self.lock: + if not self.tenant_char_count: + return None + + return min(self.tenant_char_count, key=self.tenant_char_count.get, default=None) + + def _add_tenant(self, tenant: str) -> None: + """ + Add a new tenant to the tree. + + If the tenant already exists, this is a no-op with a warning log. + + Args: + tenant: Tenant ID to add + """ + with self.lock: + if tenant in self.tenants: + logger.warning(f"Tenant '{tenant}' already exists. No action taken.") + return + + self.tenants.add(tenant) + self.tenant_char_count[tenant] = 0 + self.tenant_nodes[tenant] = set() + self.tenant_nodes_sorted[tenant] = [] diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index f66cd46418622..f9214546af86b 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -3,7 +3,9 @@ import ray from ray import serve -from ray.llm._internal.serve.deployments.routers.prefix_tree import PrefixTree +from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( + PrefixTree, +) @pytest.fixture(scope="module", autouse=True) @@ -18,22 +20,25 @@ def serve_instance(): @pytest.mark.asyncio async def test_add_tenant(): - """Test adding tenants to the tree.""" + """Test adding tenants to the tree via the private _add_tenant method.""" tree = serve.run(PrefixTree.bind()) # 1. Test basic tenant addition await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") + await tree._add_tenant.remote("tenant_1") tree_rep = await tree.to_dict.remote() assert "tenant_1" in tree_rep["tenants"] assert tree_rep["tenant_char_count"]["tenant_1"] == 0 assert tree_rep["tenant_nodes"]["tenant_1"] == set() - # 2. Test adding duplicate tenant raises ValueError + # 2. Test adding duplicate tenant logs warning but doesn't raise error await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - with pytest.raises(ValueError): - await tree.add_tenant.remote("tenant_1") + await tree._add_tenant.remote("tenant_1") + # This should not raise an error now + await tree._add_tenant.remote("tenant_1") + # Verify the tenant still exists + tree_rep = await tree.to_dict.remote() + assert "tenant_1" in tree_rep["tenants"] @pytest.mark.asyncio @@ -43,8 +48,8 @@ async def test_insert(): # 1. Test basic insertion await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("hello", "tenant_1") + # No need to call add_tenant first - insert will do it automatically + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] @@ -55,11 +60,10 @@ async def test_insert(): # 2. Test duplicate insertion doesn't double count await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("foo", "tenant_1") - await tree.insert.remote("foo", "tenant_1") # duplicate - await tree.insert.remote("bar", "tenant_2") + # Insert automatically adds tenants + await tree.insert.remote("foo", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("foo", "tenant_1", int(time.time() * 1000)) # duplicate + await tree.insert.remote("bar", "tenant_2", int(time.time() * 1000)) tree_rep = await tree.to_dict.remote() assert tree_rep["tenant_char_count"]["tenant_1"] == 3 @@ -67,10 +71,8 @@ async def test_insert(): # 3. Test node splitting on partial match await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("helloworld", "tenant_1") - await tree.insert.remote("hellothere", "tenant_2") + await tree.insert.remote("helloworld", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) tree_rep = await tree.to_dict.remote() root = tree_rep["root"] @@ -80,10 +82,15 @@ async def test_insert(): assert h_node.children.get("w").text == "world" assert h_node.children.get("t").text == "there" - # 4. Test inserting for non-existent tenant raises ValueError + # 4. Test inserting for non-existent tenant automatically adds the tenant await tree.reset.remote() - with pytest.raises(ValueError): - await tree.insert.remote("hello", "nonexistent_tenant") + # This should not raise an error now + await tree.insert.remote("hello", "nonexistent_tenant", int(time.time() * 1000)) + + # Verify the tenant was added + tree_rep = await tree.to_dict.remote() + assert "nonexistent_tenant" in tree_rep["tenants"] + assert tree_rep["tenant_char_count"]["nonexistent_tenant"] == 5 @pytest.mark.asyncio @@ -91,18 +98,16 @@ async def test_prefix_match(): """Test the prefix_match functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) - # 1. Test no match - await tree.reset.remote() - matched_text, tenants = await tree.prefix_match.remote("hello") - assert matched_text == "" - assert tenants is None + # # 1. Test no match + # await tree.reset.remote() + # matched_text, tenants = await tree.prefix_match.remote("hello") + # assert matched_text == "" + # assert tenants is None # 2. Test match with non-existing prefix returns empty string and all tenants await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("hello", "tenant_1") - await tree.insert.remote("hellothere", "tenant_2") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) matched_text, tenants = await tree.prefix_match.remote("foobar") assert matched_text == "" assert len(tenants) == 2 @@ -111,48 +116,39 @@ async def test_prefix_match(): # 3. Test exact match await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) matched_text, tenants = await tree.prefix_match.remote("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] # 4. Test partial match await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("apple", "tenant_1") - await tree.insert.remote("apricot", "tenant_2") + await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) text, tenants = await tree.prefix_match.remote("application") assert text == "appl" assert tenants == ["tenant_1"] # 5. Test match by tenant await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("apple", "tenant_1") - await tree.insert.remote("apricot", "tenant_2") + await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) text, tenants = await tree.prefix_match.remote("application", ["tenant_2"]) assert text == "ap" assert tenants == ["tenant_2"] # 6. Test match by non-existent tenant await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("apple", "tenant_1") - await tree.insert.remote("apricot", "tenant_2") + await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) text, tenants = await tree.prefix_match.remote("application", ["tenant_3"]) assert text == "" assert tenants is None # 7. Test shared prefix matching with branches await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("helloworld", "tenant_1") - await tree.insert.remote("hellothere", "tenant_2") + await tree.insert.remote("helloworld", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) text_a, tenants_a = await tree.prefix_match.remote("helloworld") text_b, tenants_b = await tree.prefix_match.remote("hellothereworld") assert text_a == "helloworld" @@ -168,8 +164,7 @@ async def test_remove_tenant(): # 1. Test basic tenant removal await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) removed = await tree.remove_tenant.remote("tenant_1") assert removed == 5 @@ -180,9 +175,8 @@ async def test_remove_tenant(): # 2. Test removing tenant with multiple nodes await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("cat", "tenant_1") - await tree.insert.remote("dog", "tenant_1") + await tree.insert.remote("cat", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("dog", "tenant_1", int(time.time() * 1000)) removed = await tree.remove_tenant.remote("tenant_1") assert removed == len("cat") + len("dog") @@ -193,10 +187,8 @@ async def test_remove_tenant(): # 4. Test tree structure after removing tenant await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("hello", "tenant_1") - await tree.insert.remote("hello", "tenant_2") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("hello", "tenant_2", int(time.time() * 1000)) # Remove tenant_1, verify tenant_2 still works await tree.remove_tenant.remote("tenant_1") @@ -211,10 +203,8 @@ async def test_remove_tenant(): # 5. Test removing the last tenant from a node removes the node await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("unique1", "tenant_1") - await tree.insert.remote("unique2", "tenant_2") + await tree.insert.remote("unique1", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("unique2", "tenant_2", int(time.time() * 1000)) # Remove tenant_1 await tree.remove_tenant.remote("tenant_1") @@ -228,7 +218,7 @@ async def test_remove_tenant(): @pytest.mark.asyncio -async def test_remove_tenant_single_node(): +async def test__remove_tenant_single_node(): """Test removing a single node for a tenant.""" tree = serve.run(PrefixTree.bind()) @@ -237,10 +227,10 @@ async def test_remove_tenant_single_node(): # The node from insert.remote() is not identity-equal to the one in tenant_nodes # await tree.reset.remote() - # await tree.add_tenant.remote("tenant_1") - # h_node = await tree.insert.remote("hello", "tenant_1") + # await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + # h_node = await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - # removed = await tree.remove_tenant_single_node.remote("tenant_1", h_node) + # removed = await tree._remove_tenant_single_node.remote("tenant_1", h_node) # assert removed == 5 # tree_rep = await tree.to_dict.remote() @@ -249,28 +239,26 @@ async def test_remove_tenant_single_node(): # 2. Test removing node for non-existent tenant raises ValueError await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) tree_rep = await tree.to_dict.remote() root = tree_rep["root"] h_node = root.children.get("h") with pytest.raises(ValueError): - await tree.remove_tenant_single_node.remote("nonexistent_tenant", h_node) + await tree._remove_tenant_single_node.remote("nonexistent_tenant", h_node) # 3. Test removing node that doesn't belong to tenant raises ValueError await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("hello", "tenant_1") + await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + await tree.insert.remote("world", "tenant_2", int(time.time() * 1000)) tree_rep = await tree.to_dict.remote() root = tree_rep["root"] h_node = root.children.get("h") with pytest.raises(ValueError): - await tree.remove_tenant_single_node.remote("tenant_2", h_node) + await tree._remove_tenant_single_node.remote("tenant_2", h_node) @pytest.mark.asyncio @@ -280,12 +268,14 @@ async def test_evict_tenant_by_LRU(): # 1. Test eviction with LRU ordering await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.insert.remote("a", "tenant_1") + current_time = int(time.time() * 1000) + await tree.insert.remote("a", "tenant_1", current_time) time.sleep(0.001) - await tree.insert.remote("bb", "tenant_1") + current_time = int(time.time() * 1000) + await tree.insert.remote("bb", "tenant_1", current_time) time.sleep(0.001) - await tree.insert.remote("ccc", "tenant_1") + current_time = int(time.time() * 1000) + await tree.insert.remote("ccc", "tenant_1", current_time) tree_rep = await tree.to_dict.remote() before = tree_rep["tenant_char_count"]["tenant_1"] @@ -306,15 +296,13 @@ async def test_evict_tenant_by_LRU(): # 3. Test eviction of tenant with insufficient characters raises ValueError await tree.reset.remote() - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("xyz", "tenant_2") + await tree.insert.remote("xyz", "tenant_2", int(time.time() * 1000)) with pytest.raises(ValueError): await tree.evict_tenant_by_LRU.remote("tenant_2", 4) # 4. Test eviction of all tenant data await tree.reset.remote() - await tree.add_tenant.remote("tenant_2") - await tree.insert.remote("xyz", "tenant_2") + await tree.insert.remote("xyz", "tenant_2", int(time.time() * 1000)) tree_rep = await tree.to_dict.remote() total_size = tree_rep["tenant_char_count"]["tenant_2"] @@ -338,24 +326,20 @@ async def test_get_smallest_tenant(): # 2. Test with multiple tenants of different sizes await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.add_tenant.remote("tenant_3") - await tree.insert.remote("aaaa", "tenant_1") - await tree.insert.remote("bb", "tenant_2") - await tree.insert.remote("c", "tenant_3") + current_time = int(time.time() * 1000) + await tree.insert.remote("aaaa", "tenant_1", current_time) + await tree.insert.remote("bb", "tenant_2", current_time) + await tree.insert.remote("c", "tenant_3", current_time) smallest = await tree.get_smallest_tenant.remote() assert smallest == "tenant_3" # 3. Test after removing the smallest tenant await tree.reset.remote() - await tree.add_tenant.remote("tenant_1") - await tree.add_tenant.remote("tenant_2") - await tree.add_tenant.remote("tenant_3") - await tree.insert.remote("aaaa", "tenant_1") - await tree.insert.remote("bb", "tenant_2") - await tree.insert.remote("c", "tenant_3") + current_time = int(time.time() * 1000) + await tree.insert.remote("aaaa", "tenant_1", current_time) + await tree.insert.remote("bb", "tenant_2", current_time) + await tree.insert.remote("c", "tenant_3", current_time) await tree.remove_tenant.remote("tenant_3") smallest = await tree.get_smallest_tenant.remote() assert smallest == "tenant_2" From 9c110bbabfa4b07a0b8366b40a4920b84b34c655 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Tue, 6 May 2025 14:58:47 -0700 Subject: [PATCH 07/15] Address comments Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 108 +++++++++++------- .../serve/cpu/deployments/test_prefix_tree.py | 83 +++++++++++++- 2 files changed, 143 insertions(+), 48 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 90fcb765571e1..989c6c3cd622c 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -14,38 +14,15 @@ class Node: """ - Node in a prefix tree that tracks tenant access time. - - Each node represents a segment of text and can belong to multiple tenants. - - Example tree structure: - Representing the strings inserted in order: - - "helloworld" at time 1 by tenant_1 - - "hellothere" at time 2 by tenant_2 - - "hellothomas" at time 3 by tenant_2 - - root: [] - {tenant_1: 1, tenant_2: 3} - | - (h)| - | - [hello] - {tenant_1: 1, tenant_2: 3} - / \ - (w)/ \\(t) - / \ - [world] [th] - {tenant_1: 1} {tenant_2: 3} - / \ - (e)/ \\(o) - / \ - [ere] [omas] - {tenant_2: 2} {tenant_2: 3} - - Legend for each node: - - [text] = Node.text - - {tenant, timestamp} = Node.tenant_last_access_time - - (x) = edge label (first character used as key for parent's children) + Node in a prefix tree that represents a segment of text and can belong to multiple tenants. + Each node also tracks the last access time for each tenant. + Simple example of root node connected to two children Nodes: + root = Node(text="", parent=None, children={"f": fooNode, "b": barNode}, tenant_last_access_time={"tenant_1": 2}) + fooNode = Node(text="foo", parent=root, children={}, tenant_last_access_time={"tenant_1": 1}) + barNode = Node(text="bar", parent=root, children={}, tenant_last_access_time={"tenant_1": 2}) + + In the above example, "foo" was inserted at time 1, and "bar" was inserted at time 2. + It follows that root was last accessed at time 2. """ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: @@ -57,11 +34,11 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: parent: The parent node of this node """ self.text: str = text - self.parent: Optional[Node] = parent + self.parent: Optional[Node] = parent # The parent node of this node self.children: Dict[str, Node] = {} # Maps first character to child node self.tenant_last_access_time: Dict[str, int] = ( {} - ) # Maps tenant ID to last access timestamp (in milliseconds) + ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) def __repr__(self) -> str: return f"Node(text='{self.text}', children={list(self.children.keys())}, tenants={list(self.tenant_last_access_time.keys())})" @@ -73,7 +50,7 @@ class TenantHeapNode: Used for efficient LRU eviction of tenant nodes. """ - def __init__(self, node: Node, tenant: str) -> None: + def __init__(self, node: Node, tenant_ordering_key: str) -> None: """ Initialize a heap node for efficient LRU tenant management. @@ -82,7 +59,7 @@ def __init__(self, node: Node, tenant: str) -> None: tenant: The tenant ID this heap node is associated with """ self.node = node - self.tenant_ordering_key = tenant + self.tenant_ordering_key = tenant_ordering_key def __lt__(self, other: TenantHeapNode) -> bool: """ @@ -113,22 +90,48 @@ class PrefixTree: 2. Thread-safe with node-level locking for concurrent access 3. LRU eviction based on tenant access time 4. Efficient prefix matching across multiple tenants + + + Example tree structure: + Representing the strings inserted in order: + - "helloworld" at time 1 by tenant_1 + - "hellothere" at time 2 by tenant_2 + - "hellothomas" at time 3 by tenant_2 + + root: [] {tenant_1: 1, tenant_2: 3} + (h) → [hello] {tenant_1: 1, tenant_2: 3} + (w) → [world] {tenant_1: 1} + (t) → [th] {tenant_2: 3} + (e) → [ere] {tenant_2: 2} + (o) → [omas] {tenant_2: 3} + + Legend for each node: + - [text] = Node.text + - {tenant, timestamp} = Node.tenant_last_access_time + - (x) = edge label (first character used as key for parent's children) + + PrefixTree instance variables: + self.tenants = {"tenant_1", "tenant_2"} + self.tenant_char_count = {"tenant_1": 10, "tenant_2": 14} + self.tenant_nodes = {"tenant_1": {root, Node("hello"), Node("world")}, "tenant_2": {root, Node("hello"), Node("th"), Node("ere"), Node("omas")}} + self.tenant_nodes_sorted = {"tenant_1": [root, Node("hello"), Node("world")], "tenant_2": [Node("ere"), root, Node("hello"), Node("th"), Node("omas")]} + # Note: self.tenant_nodes_sorted is maintained as a min-heap, so the first element is guaranteed to be the least recently used node for that tenant, but the rest of the heap is not guaranteed to be sorted. """ def __init__(self) -> None: """Initialize an empty prefix tree.""" self.lock: RLock = RLock() self.root: Node = Node() - self.tenants: Set[str] = set() # Set of tenant IDs + self.tenants: Set[str] = set() # Set of tenant IDs in the tree self.tenant_char_count: Dict[str, int] = ( {} ) # Tracks total character count per tenant self.tenant_nodes: Dict[str, Set[Node]] = ( {} - ) # Maps tenant ID to set of nodes belonging to that tenant + ) # Maps tenant ID to set of nodes belonging to that tenant. Used for O(1) lookup of whether a node belongs to a tenant. self.tenant_nodes_sorted: Dict[str, List[TenantHeapNode]] = ( {} - ) # Maps tenant ID to heap of nodes for LRU eviction + ) # Maps tenant ID to heap of nodes for LRU eviction. Used for O(log n) insertion and eviction of LRU node. def reset(self) -> None: """Reset the tree to an empty state.""" @@ -185,6 +188,23 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: Returns: The node that was inserted or updated + + Note: + Loop structure: + 1. At the start of each iteration, curr_node is a node we potentially update. + e.g. node.tenant_last_access_time[tenant], self.tenant_char_count, + self.tenant_nodes, self.tenant_nodes_sorted + 2. Each iteration then either: + a. Breaks (if we've processed the entire string). + b. Processes the next segment of text by: + 1. If no child exists for the first character, create a new leaf node that matches the current text. + 2. Then, match the current text with the child's text: + a. If they share a prefix (partial match), split the node and traverse into the new parent. + b. If they fully match, traverse into the child node. + 3. The self.tenant_nodes_sorted heap is reheapified at each node visit to maintain LRU order. + + This structure allows us to efficiently insert text while maintaining shared prefixes + and tracking tenant access times for the LRU eviction mechanism. """ with self.lock: if tenant not in self.tenants: @@ -197,15 +217,14 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: # Invariant: assume curr_node has not been visited by tenant yet # Update tenant info for current node if tenant not in curr_node.tenant_last_access_time: - self.tenant_nodes[tenant].add(curr_node) self.tenant_char_count[tenant] += len(curr_node.text) + self.tenant_nodes[tenant].add(curr_node) self.tenant_nodes_sorted[tenant].append( TenantHeapNode(curr_node, tenant) ) curr_node.tenant_last_access_time[tenant] = timestamp_ms heapq.heapify(self.tenant_nodes_sorted[tenant]) - if i == len(text): break @@ -249,6 +268,11 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: new_parent.tenant_last_access_time = ( matched_node.tenant_last_access_time.copy() ) + for existing_tenant in new_parent.tenant_last_access_time: + self.tenant_nodes[existing_tenant].add(new_parent) + self.tenant_nodes_sorted[existing_tenant].append( + TenantHeapNode(new_parent, existing_tenant) + ) # Update existing matched node matched_node.text = remaining_text @@ -410,7 +434,7 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len - def evict_tenant_by_LRU(self, tenant: str, min_remove_size: int) -> int: + def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ Evict least recently used nodes for a tenant until minimum size is freed. diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index f9214546af86b..a416ad6815a12 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -2,6 +2,7 @@ import time import ray from ray import serve +import heapq from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( PrefixTree, @@ -262,8 +263,8 @@ async def test__remove_tenant_single_node(): @pytest.mark.asyncio -async def test_evict_tenant_by_LRU(): - """Test the evict_tenant_by_LRU functionality of PrefixTree.""" +async def test_evict_tenant_by_lru(): + """Test the evict_tenant_by_lru functionality of PrefixTree.""" tree = serve.run(PrefixTree.bind()) # 1. Test eviction with LRU ordering @@ -280,7 +281,7 @@ async def test_evict_tenant_by_LRU(): tree_rep = await tree.to_dict.remote() before = tree_rep["tenant_char_count"]["tenant_1"] - evicted = await tree.evict_tenant_by_LRU.remote("tenant_1", 2) + evicted = await tree.evict_tenant_by_lru.remote("tenant_1", 2) tree_rep = await tree.to_dict.remote() after = tree_rep["tenant_char_count"]["tenant_1"] @@ -292,13 +293,13 @@ async def test_evict_tenant_by_LRU(): # 2. Test eviction of non-existent tenant raises ValueError await tree.reset.remote() with pytest.raises(ValueError): - await tree.evict_tenant_by_LRU.remote("nonexistent_tenant", 5) + await tree.evict_tenant_by_lru.remote("nonexistent_tenant", 5) # 3. Test eviction of tenant with insufficient characters raises ValueError await tree.reset.remote() await tree.insert.remote("xyz", "tenant_2", int(time.time() * 1000)) with pytest.raises(ValueError): - await tree.evict_tenant_by_LRU.remote("tenant_2", 4) + await tree.evict_tenant_by_lru.remote("tenant_2", 4) # 4. Test eviction of all tenant data await tree.reset.remote() @@ -307,12 +308,82 @@ async def test_evict_tenant_by_LRU(): tree_rep = await tree.to_dict.remote() total_size = tree_rep["tenant_char_count"]["tenant_2"] - evicted = await tree.evict_tenant_by_LRU.remote("tenant_2", total_size) + evicted = await tree.evict_tenant_by_lru.remote("tenant_2", total_size) assert evicted == total_size tree_rep = await tree.to_dict.remote() assert "tenant_2" in tree_rep["tenants"] + # 5. Test tree structure and LRU heap ordering + await tree.reset.remote() + + # Insert strings in specified order + await tree.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 + await tree.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 + await tree.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 + + # Get tree representation for testing + tree_rep = await tree.to_dict.remote() + root = tree_rep["root"] + + # Test tree structure - validate each node + # Root node + assert root.text == "" + assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "h" in root.children + + # Hello node + hello_node = root.children["h"] + assert hello_node.text == "hello" + assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "w" in hello_node.children + assert "t" in hello_node.children + + # World node + world_node = hello_node.children["w"] + assert world_node.text == "world" + assert world_node.tenant_last_access_time == {"tenant_1": 1} + assert len(world_node.children) == 0 + + # Th node + th_node = hello_node.children["t"] + assert th_node.text == "th" + assert th_node.tenant_last_access_time == {"tenant_2": 3} + assert "e" in th_node.children + assert "o" in th_node.children + + # Ere node + ere_node = th_node.children["e"] + assert ere_node.text == "ere" + assert ere_node.tenant_last_access_time == {"tenant_2": 2} + assert len(ere_node.children) == 0 + + # Omas node + omas_node = th_node.children["o"] + assert omas_node.text == "omas" + assert omas_node.tenant_last_access_time == {"tenant_2": 3} + assert len(omas_node.children) == 0 + + # Test PrefixTree instance variables + assert tree_rep["tenants"] == {"tenant_1", "tenant_2"} + + # Test tenant_char_count + assert tree_rep["tenant_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 + assert tree_rep["tenant_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + + # Test tenant_nodes (check by text) + tenant1_nodes_texts = {node.text for node in tree_rep["tenant_nodes"]["tenant_1"]} + assert tenant1_nodes_texts == {"", "hello", "world"} + + tenant2_nodes_texts = {node.text for node in tree_rep["tenant_nodes"]["tenant_2"]} + assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} + + # Test tenant_nodes_sorted - validate heap ordering + assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_1"]).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_1"]).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_2"]).node.tenant_last_access_time["tenant_2"] == 2 + assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_2"]).node.tenant_last_access_time["tenant_2"] == 3 + @pytest.mark.asyncio async def test_get_smallest_tenant(): From 8878844e1fcd68d2ce15c5ca406508b2d71fff00 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Tue, 6 May 2025 18:17:40 -0700 Subject: [PATCH 08/15] Clean up code, separate base class from serve deployment Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 226 +++---- .../serve/cpu/deployments/test_prefix_tree.py | 621 +++++++++++------- 2 files changed, 482 insertions(+), 365 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 989c6c3cd622c..e3fd81cd139cb 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -4,11 +4,10 @@ import logging import os from threading import RLock -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Optional, Set, Tuple, Any from ray import serve -# Logger for this module logger = logging.getLogger(__name__) @@ -80,7 +79,6 @@ def __repr__(self) -> str: return f"TenantHeapNode(node={self.node}, tenant_ordering_key={self.tenant_ordering_key})" -@serve.deployment(name="TreeDeployment") class PrefixTree: """ Thread-safe multi-tenant prefix tree (approximate radix tree). @@ -133,8 +131,28 @@ def __init__(self) -> None: {} ) # Maps tenant ID to heap of nodes for LRU eviction. Used for O(log n) insertion and eviction of LRU node. - def reset(self) -> None: - """Reset the tree to an empty state.""" + + @staticmethod + def _shared_prefix_count(a: str, b: str) -> int: + """ + Count the number of shared characters at the beginning of two strings. + + Args: + a: First string + b: Second string + + Returns: + Number of matching characters at the beginning + """ + return len(os.path.commonprefix([a, b])) + + + def _reset(self) -> None: + """ + Reset the tree to an empty state. + + Note: This method is intended to be used only in tests. + """ with self.lock: self.root = Node() self.tenants = set() @@ -142,40 +160,69 @@ def reset(self) -> None: self.tenant_nodes = {} self.tenant_nodes_sorted = {} - def to_dict(self) -> Dict: + + def _add_tenant(self, tenant: str) -> None: """ - Convert tree to dictionary for serialization. + Add a new tenant to the tree. - Returns: - Dictionary representation of the tree + If the tenant already exists, this is a no-op with a warning log. + + Args: + tenant: Tenant ID to add """ - return { - "root": self.root, - "tenants": self.tenants, - "tenant_char_count": self.tenant_char_count, - "tenant_nodes": self.tenant_nodes, - "tenant_nodes_sorted": self.tenant_nodes_sorted, - } + with self.lock: + if tenant in self.tenants: + logger.warning(f"Tenant '{tenant}' already exists. No action taken.") + return + + self.tenants.add(tenant) + self.tenant_char_count[tenant] = 0 + self.tenant_nodes[tenant] = set() + self.tenant_nodes_sorted[tenant] = [] - def to_string(self) -> str: - """String representation of the tree.""" - return f"PrefixTree(tenants={self.tenants}, tenant_char_count={self.tenant_char_count}, tenant_nodes={self.tenant_nodes}, tenant_nodes_sorted={self.tenant_nodes_sorted})" - @staticmethod - def _shared_prefix_count(a: str, b: str) -> int: + def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ - Count the number of shared characters at the beginning of two strings. + Remove a tenant from a single node. Args: - a: First string - b: Second string + tenant: Tenant ID to remove + node: Node to remove tenant from Returns: - Number of matching characters at the beginning + Number of characters removed """ - return len(os.path.commonprefix([a, b])) + with self.lock: + if tenant not in self.tenants: + logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + return 0 + + if ( + node not in self.tenant_nodes[tenant] + or tenant not in node.tenant_last_access_time + ): + logger.warning( + f"Cannot remove node '{node.text}' from tenant '{tenant}': " + f"tenant does not have this node. No action taken." + ) + return 0 - def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: + removed_chars_len: int = len(node.text) + self.tenant_char_count[tenant] -= removed_chars_len + self.tenant_nodes[tenant].remove(node) + node.tenant_last_access_time.pop(tenant, None) + + # Clean up empty nodes + if not node.tenant_last_access_time and node.parent: + if ( + node.text and node.text[0] in node.parent.children + ): # Defensive check + node.parent.children.pop(node.text[0], None) + + return removed_chars_len + + + def insert(self, text: str, tenant: str, time_sec: float) -> Node: """ Insert text into tree for a specific tenant. @@ -184,7 +231,7 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: Args: text: Text to insert tenant: Tenant ID - timestamp_ms: Current timestamp in milliseconds + time_sec: Current timestamp in seconds Returns: The node that was inserted or updated @@ -223,7 +270,7 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: TenantHeapNode(curr_node, tenant) ) - curr_node.tenant_last_access_time[tenant] = timestamp_ms + curr_node.tenant_last_access_time[tenant] = time_sec heapq.heapify(self.tenant_nodes_sorted[tenant]) if i == len(text): break @@ -292,6 +339,7 @@ def insert(self, text: str, tenant: str, timestamp_ms: int) -> Node: return curr_node + def prefix_match( self, text: str, available_tenants: Optional[List[str]] = None ) -> Tuple[str, Optional[List[str]]]: @@ -352,14 +400,12 @@ def prefix_match( tenant for tenant in available_tenants if tenant in curr_node.tenant_last_access_time - ] + ] or None - selected_tenants: Optional[List[str]] = ( - matching_tenants if matching_tenants else None - ) matched_text: str = text[:i] - return matched_text, selected_tenants + return matched_text, matching_tenants + def remove_tenant(self, tenant: str) -> int: """ @@ -370,15 +416,11 @@ def remove_tenant(self, tenant: str) -> int: Returns: Number of characters removed - - Raises: - ValueError: If tenant does not exist """ with self.lock: if tenant not in self.tenants: - raise ValueError( - f"Cannot remove tenant '{tenant}': tenant does not exist" - ) + logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + return 0 total_chars_removed: int = 0 for node in self.tenant_nodes[tenant].copy(): @@ -391,48 +433,6 @@ def remove_tenant(self, tenant: str) -> int: return total_chars_removed - def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: - """ - Remove a tenant from a single node. - - Args: - tenant: Tenant ID to remove - node: Node to remove tenant from - - Returns: - Number of characters removed - - Raises: - ValueError: If tenant does not exist or node doesn't belong to tenant - """ - with self.lock: - if tenant not in self.tenants: - raise ValueError( - f"Cannot remove tenant '{tenant}': tenant does not exist" - ) - - if ( - node not in self.tenant_nodes[tenant] - or tenant not in node.tenant_last_access_time - ): - raise ValueError( - f"Cannot remove node '{node.text}' from tenant '{tenant}': " - f"tenant does not have this node" - ) - - removed_chars_len: int = len(node.text) - self.tenant_char_count[tenant] -= removed_chars_len - self.tenant_nodes[tenant].remove(node) - node.tenant_last_access_time.pop(tenant, None) - - # Clean up empty nodes - if not node.tenant_last_access_time and node.parent: - if ( - node.text and node.text[0] in node.parent.children - ): # Defensive check - node.parent.children.pop(node.text[0], None) - - return removed_chars_len def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ @@ -444,22 +444,20 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: Returns: Actual number of characters removed - - Raises: - ValueError: If tenant doesn't exist or has insufficient nodes """ with self.lock: if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: - raise ValueError( - f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes" + logger.warning( + f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes. No action taken." ) + return 0 if self.tenant_char_count[tenant] < min_remove_size: - raise ValueError( - f"Cannot evict tenant '{tenant}': total character count " - f"({self.tenant_char_count[tenant]}) is less than min_remove_size " - f"({min_remove_size})" + logger.warning( + f"Cannot evict {min_remove_size} characters from tenant '{tenant}', which has only " + f"{self.tenant_char_count[tenant]} characters. Will remove all available characters." ) + min_remove_size = self.tenant_char_count[tenant] total_chars_removed: int = 0 @@ -468,15 +466,20 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: total_chars_removed < min_remove_size and self.tenant_nodes_sorted[tenant] ): - heap_node: TenantHeapNode = heapq.heappop( - self.tenant_nodes_sorted[tenant] - ) - total_chars_removed += self._remove_tenant_single_node( - tenant, heap_node.node - ) + # Get the minimum access time from the top of the heap + oldest_access_time = self.tenant_nodes_sorted[tenant][0].node.tenant_last_access_time[tenant] + + # Remove all nodes with this same access time + while (self.tenant_nodes_sorted[tenant] and + self.tenant_nodes_sorted[tenant][0].node.tenant_last_access_time[tenant] == oldest_access_time): + heap_node: TenantHeapNode = heapq.heappop(self.tenant_nodes_sorted[tenant]) + total_chars_removed += self._remove_tenant_single_node( + tenant, heap_node.node + ) return total_chars_removed + def get_smallest_tenant(self) -> Optional[str]: """ Get the tenant with the smallest total character count. @@ -490,21 +493,22 @@ def get_smallest_tenant(self) -> Optional[str]: return min(self.tenant_char_count, key=self.tenant_char_count.get, default=None) - def _add_tenant(self, tenant: str) -> None: + +@serve.deployment(name="TreeDeployment") +class PrefixTreeDeployment(PrefixTree): + def _to_dict(self) -> Dict[str, Any]: """ - Add a new tenant to the tree. + Convert tree to dictionary for serialization. - If the tenant already exists, this is a no-op with a warning log. + Returns: + Dictionary representation of the tree - Args: - tenant: Tenant ID to add + Note: This method is intended to be used only in tests. """ - with self.lock: - if tenant in self.tenants: - logger.warning(f"Tenant '{tenant}' already exists. No action taken.") - return - - self.tenants.add(tenant) - self.tenant_char_count[tenant] = 0 - self.tenant_nodes[tenant] = set() - self.tenant_nodes_sorted[tenant] = [] + return { + "root": self.root, + "tenants": self.tenants, + "tenant_char_count": self.tenant_char_count, + "tenant_nodes": self.tenant_nodes, + "tenant_nodes_sorted": self.tenant_nodes_sorted, + } \ No newline at end of file diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index a416ad6815a12..abaea937871b4 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -3,14 +3,20 @@ import ray from ray import serve import heapq +from typing import Set, List, Dict, Optional, Generator, Any from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( - PrefixTree, + PrefixTree, PrefixTreeDeployment, Node, TenantHeapNode ) +# Fixtures +@pytest.fixture +def tree() -> PrefixTree: + """Create a fresh PrefixTree instance for each test.""" + return PrefixTree() @pytest.fixture(scope="module", autouse=True) -def serve_instance(): +def serve_instance() -> Generator[None, None, None]: # Start Ray and Serve once per test module ray.init(ignore_reinit_error=True) serve.start(detached=True) @@ -18,313 +24,422 @@ def serve_instance(): serve.shutdown() ray.shutdown() +@pytest.fixture(scope="module") +def tree_deployment(): + """Create a fresh PrefixTreeDeployment instance for each test.""" + tree = serve.run(PrefixTreeDeployment.bind()) + return tree +# PrefixTreeDeployment tests @pytest.mark.asyncio -async def test_add_tenant(): - """Test adding tenants to the tree via the private _add_tenant method.""" - tree = serve.run(PrefixTree.bind()) +async def test_tree_deployment(tree_deployment) -> None: + """Test the PrefixTreeDeployment.""" + # 6. Test tree structure and LRU heap ordering + await tree_deployment._reset.remote() + + # Insert strings in specified order + await tree_deployment.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 + await tree_deployment.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 + await tree_deployment.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 + + # Access tree directly + tree_rep: Dict = await tree_deployment._to_dict.remote() + root: Node = tree_rep["root"] + + # Test tree structure - validate each node + # Root node + assert root.text == "" + assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "h" in root.children + + # Hello node + hello_node: Node = root.children["h"] + assert hello_node.text == "hello" + assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "w" in hello_node.children + assert "t" in hello_node.children + + # World node + world_node: Node = hello_node.children["w"] + assert world_node.text == "world" + assert world_node.tenant_last_access_time == {"tenant_1": 1} + assert len(world_node.children) == 0 + + # Th node + th_node: Node = hello_node.children["t"] + assert th_node.text == "th" + assert th_node.tenant_last_access_time == {"tenant_2": 3} + assert "e" in th_node.children + assert "o" in th_node.children + + # Ere node + ere_node: Node = th_node.children["e"] + assert ere_node.text == "ere" + assert ere_node.tenant_last_access_time == {"tenant_2": 2} + assert len(ere_node.children) == 0 + + # Omas node + omas_node: Node = th_node.children["o"] + assert omas_node.text == "omas" + assert omas_node.tenant_last_access_time == {"tenant_2": 3} + assert len(omas_node.children) == 0 + + # Test PrefixTree instance variables + assert tree_rep["tenants"] == {"tenant_1", "tenant_2"} + + # Test tenant_char_count + assert tree_rep["tenant_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 + assert tree_rep["tenant_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + + # Test tenant_nodes (check by text) + tenant1_nodes_texts: Set[str] = {node.text for node in tree_rep["tenant_nodes"]["tenant_1"]} + assert tenant1_nodes_texts == {"", "hello", "world"} + + tenant2_nodes_texts: Set[str] = {node.text for node in tree_rep["tenant_nodes"]["tenant_2"]} + assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} + + # Test tenant_nodes_sorted - validate heap ordering + tenant1_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_1"] + tenant2_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_2"] + + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 + assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 3 + +# PrefixTree tests +def test__add_tenant(tree: PrefixTree) -> None: + """Test adding tenants to the tree via the private _add_tenant method.""" # 1. Test basic tenant addition - await tree.reset.remote() - await tree._add_tenant.remote("tenant_1") - tree_rep = await tree.to_dict.remote() - assert "tenant_1" in tree_rep["tenants"] - assert tree_rep["tenant_char_count"]["tenant_1"] == 0 - assert tree_rep["tenant_nodes"]["tenant_1"] == set() + tree._reset() + tree._add_tenant("tenant_1") + assert "tenant_1" in tree.tenants + assert tree.tenant_char_count["tenant_1"] == 0 + assert tree.tenant_nodes["tenant_1"] == set() # 2. Test adding duplicate tenant logs warning but doesn't raise error - await tree.reset.remote() - await tree._add_tenant.remote("tenant_1") - # This should not raise an error now - await tree._add_tenant.remote("tenant_1") + tree._reset() + tree._add_tenant("tenant_1") + # This should be a no-op + tree._add_tenant("tenant_1") # Verify the tenant still exists - tree_rep = await tree.to_dict.remote() - assert "tenant_1" in tree_rep["tenants"] + assert "tenant_1" in tree.tenants -@pytest.mark.asyncio -async def test_insert(): +def test_insert(tree: PrefixTree) -> None: """Test the insert functionality of PrefixTree.""" - tree = serve.run(PrefixTree.bind()) - # 1. Test basic insertion - await tree.reset.remote() + tree._reset() # No need to call add_tenant first - insert will do it automatically - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - matched_text, tenants = await tree.prefix_match.remote("hello") + tree.insert("hello", "tenant_1", 1) + matched_text, tenants = tree.prefix_match("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] - tree_rep = await tree.to_dict.remote() - assert tree_rep["tenant_char_count"]["tenant_1"] == 5 - assert len(tree_rep["tenant_nodes"]["tenant_1"]) == 2 + assert tree.tenant_char_count["tenant_1"] == 5 + assert len(tree.tenant_nodes["tenant_1"]) == 2 # 2. Test duplicate insertion doesn't double count - await tree.reset.remote() - # Insert automatically adds tenants - await tree.insert.remote("foo", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("foo", "tenant_1", int(time.time() * 1000)) # duplicate - await tree.insert.remote("bar", "tenant_2", int(time.time() * 1000)) + tree._reset() + tree.insert("foo", "tenant_1", 1) + tree.insert("foo", "tenant_1", 1) # duplicate + tree.insert("bar", "tenant_2", 2) - tree_rep = await tree.to_dict.remote() - assert tree_rep["tenant_char_count"]["tenant_1"] == 3 - assert tree_rep["tenant_char_count"]["tenant_2"] == 3 + assert tree.tenant_char_count["tenant_1"] == 3 + assert tree.tenant_char_count["tenant_2"] == 3 # 3. Test node splitting on partial match - await tree.reset.remote() - await tree.insert.remote("helloworld", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) + tree._reset() + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) - tree_rep = await tree.to_dict.remote() - root = tree_rep["root"] - h_node = root.children.get("h") + root: Node = tree.root + h_node: Optional[Node] = root.children.get("h") assert h_node is not None assert h_node.text == "hello" assert h_node.children.get("w").text == "world" assert h_node.children.get("t").text == "there" + + # 4. Test that inserting a longer prompt with shared prefix doesn't create empty text nodes + tree._reset() + tree.insert("hello", "tenant_1", 1) + tree.insert("helloworld", "tenant_2", 2) + + root = tree.root + + # Check that only the root has empty text by directly traversing the tree + # Starting from root, collect all nodes with empty text + empty_text_nodes: List[Node] = [] + nodes_to_check: List[Node] = [root] + + while nodes_to_check: + node: Node = nodes_to_check.pop() + if node.text == "": + empty_text_nodes.append(node) + # Add all children to check + nodes_to_check.extend(node.children.values()) + + # There should be exactly one empty text node (the root) + assert len(empty_text_nodes) == 1 + assert root in empty_text_nodes + + # Verify tree structure + h_node = root.children.get("h") + assert h_node is not None + assert h_node.text == "hello" + assert "tenant_1" in h_node.tenant_last_access_time + assert "tenant_2" in h_node.tenant_last_access_time + + # Verify "world" node belongs only to tenant 2 + world_node: Optional[Node] = h_node.children.get("w") + assert world_node is not None + assert world_node.text == "world" + assert "tenant_2" in world_node.tenant_last_access_time + assert "tenant_1" not in world_node.tenant_last_access_time - # 4. Test inserting for non-existent tenant automatically adds the tenant - await tree.reset.remote() - # This should not raise an error now - await tree.insert.remote("hello", "nonexistent_tenant", int(time.time() * 1000)) - - # Verify the tenant was added - tree_rep = await tree.to_dict.remote() - assert "nonexistent_tenant" in tree_rep["tenants"] - assert tree_rep["tenant_char_count"]["nonexistent_tenant"] == 5 + # Verify the only child of h_node is "w" + assert len(h_node.children) == 1 -@pytest.mark.asyncio -async def test_prefix_match(): +def test_prefix_match(tree: PrefixTree) -> None: """Test the prefix_match functionality of PrefixTree.""" - tree = serve.run(PrefixTree.bind()) - - # # 1. Test no match - # await tree.reset.remote() - # matched_text, tenants = await tree.prefix_match.remote("hello") - # assert matched_text == "" - # assert tenants is None + # 1. Test no match + tree._reset() + matched_text, tenants = tree.prefix_match("hello") + assert matched_text == "" + assert tenants is None # 2. Test match with non-existing prefix returns empty string and all tenants - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) - matched_text, tenants = await tree.prefix_match.remote("foobar") + tree._reset() + tree.insert("hello", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + matched_text, tenants = tree.prefix_match("foobar") assert matched_text == "" assert len(tenants) == 2 assert "tenant_1" in tenants assert "tenant_2" in tenants # 3. Test exact match - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - matched_text, tenants = await tree.prefix_match.remote("hello") + tree._reset() + tree.insert("hello", "tenant_1", 1) + matched_text, tenants = tree.prefix_match("hello") assert matched_text == "hello" assert tenants == ["tenant_1"] # 4. Test partial match - await tree.reset.remote() - await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) - text, tenants = await tree.prefix_match.remote("application") + tree._reset() + tree.insert("apple", "tenant_1", 1) + tree.insert("apricot", "tenant_2", 2) + text, tenants = tree.prefix_match("application") assert text == "appl" assert tenants == ["tenant_1"] # 5. Test match by tenant - await tree.reset.remote() - await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) - text, tenants = await tree.prefix_match.remote("application", ["tenant_2"]) + tree._reset() + tree.insert("apple", "tenant_1", 1) + tree.insert("apricot", "tenant_2", 2) + text, tenants = tree.prefix_match("application", ["tenant_2"]) assert text == "ap" assert tenants == ["tenant_2"] # 6. Test match by non-existent tenant - await tree.reset.remote() - await tree.insert.remote("apple", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("apricot", "tenant_2", int(time.time() * 1000)) - text, tenants = await tree.prefix_match.remote("application", ["tenant_3"]) + tree._reset() + tree.insert("apple", "tenant_1", 1) + tree.insert("apricot", "tenant_2", 2) + text, tenants = tree.prefix_match("application", ["tenant_3"]) assert text == "" assert tenants is None # 7. Test shared prefix matching with branches - await tree.reset.remote() - await tree.insert.remote("helloworld", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("hellothere", "tenant_2", int(time.time() * 1000)) - text_a, tenants_a = await tree.prefix_match.remote("helloworld") - text_b, tenants_b = await tree.prefix_match.remote("hellothereworld") + tree._reset() + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + text_a, tenants_a = tree.prefix_match("helloworld") + text_b, tenants_b = tree.prefix_match("hellothereworld") assert text_a == "helloworld" assert tenants_a == ["tenant_1"] assert text_b == "hellothere" assert tenants_b == ["tenant_2"] -@pytest.mark.asyncio -async def test_remove_tenant(): - """Test removing a tenant from the tree.""" - tree = serve.run(PrefixTree.bind()) - - # 1. Test basic tenant removal - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - removed = await tree.remove_tenant.remote("tenant_1") - assert removed == 5 - - tree_rep = await tree.to_dict.remote() - assert "tenant_1" not in tree_rep["tenants"] - assert "tenant_1" not in tree_rep["tenant_char_count"] - assert "tenant_1" not in tree_rep["tenant_nodes"] - - # 2. Test removing tenant with multiple nodes - await tree.reset.remote() - await tree.insert.remote("cat", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("dog", "tenant_1", int(time.time() * 1000)) - removed = await tree.remove_tenant.remote("tenant_1") - assert removed == len("cat") + len("dog") - - # 3. Test removing non-existent tenant raises ValueError - await tree.reset.remote() - with pytest.raises(ValueError): - await tree.remove_tenant.remote("nonexistent_tenant") - - # 4. Test tree structure after removing tenant - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("hello", "tenant_2", int(time.time() * 1000)) - - # Remove tenant_1, verify tenant_2 still works - await tree.remove_tenant.remote("tenant_1") - - tree_rep = await tree.to_dict.remote() - assert "tenant_1" not in tree_rep["tenants"] - assert "tenant_2" in tree_rep["tenants"] - - matched_text, tenants = await tree.prefix_match.remote("hello") - assert matched_text == "hello" - assert tenants == ["tenant_2"] - - # 5. Test removing the last tenant from a node removes the node - await tree.reset.remote() - await tree.insert.remote("unique1", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("unique2", "tenant_2", int(time.time() * 1000)) - - # Remove tenant_1 - await tree.remove_tenant.remote("tenant_1") - - tree_rep = await tree.to_dict.remote() - root = tree_rep["root"] - # 'u' node should only have one child now ('2' from unique2) - assert "u" in root.children - assert "2" in root.children["u"].children # '2' from unique2 - assert len(root.children["u"].children) == 1 - - -@pytest.mark.asyncio -async def test__remove_tenant_single_node(): +def test__remove_tenant_single_node(tree: PrefixTree) -> None: """Test removing a single node for a tenant.""" - tree = serve.run(PrefixTree.bind()) - - # # 1. Test removing a single node + # 1. Test removing a single node # TEST FAILS: Ray creates new node instances when making remote calls? # The node from insert.remote() is not identity-equal to the one in tenant_nodes - # await tree.reset.remote() - # await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - # h_node = await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + tree._reset() + tree.insert("hello", "tenant_1", 1) + h_node: Node = tree.insert("hello", "tenant_1", 1) - # removed = await tree._remove_tenant_single_node.remote("tenant_1", h_node) - # assert removed == 5 - - # tree_rep = await tree.to_dict.remote() - # assert tree_rep["tenant_char_count"]["tenant_1"] == 0 - # assert tree_rep["tenant_nodes"]["tenant_1"] == set() + removed: int = tree._remove_tenant_single_node("tenant_1", h_node) + assert removed == 5 - # 2. Test removing node for non-existent tenant raises ValueError - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) + assert tree.tenant_char_count["tenant_1"] == 0 + assert len(tree.tenant_nodes["tenant_1"]) == 1 + assert tree.root in tree.tenant_nodes["tenant_1"] - tree_rep = await tree.to_dict.remote() - root = tree_rep["root"] - h_node = root.children.get("h") + # 2. Test removing node for non-existent tenant is idempotent + tree._reset() + tree.insert("hello", "tenant_1", 1) + root: Node = tree.root + h_node: Optional[Node] = root.children.get("h") - with pytest.raises(ValueError): - await tree._remove_tenant_single_node.remote("nonexistent_tenant", h_node) + # Should not raise error, just return 0 + removed = tree._remove_tenant_single_node("nonexistent_tenant", h_node) + assert removed == 0 - # 3. Test removing node that doesn't belong to tenant raises ValueError - await tree.reset.remote() - await tree.insert.remote("hello", "tenant_1", int(time.time() * 1000)) - await tree.insert.remote("world", "tenant_2", int(time.time() * 1000)) + # 3. Test removing node that doesn't belong to tenant is idempotent + tree._reset() + tree.insert("hello", "tenant_1", 1) + tree.insert("world", "tenant_2", 2) - tree_rep = await tree.to_dict.remote() - root = tree_rep["root"] + root = tree.root h_node = root.children.get("h") - with pytest.raises(ValueError): - await tree._remove_tenant_single_node.remote("tenant_2", h_node) + # Should not raise error, just return 0 + removed = tree._remove_tenant_single_node("tenant_2", h_node) + assert removed == 0 -@pytest.mark.asyncio -async def test_evict_tenant_by_lru(): - """Test the evict_tenant_by_lru functionality of PrefixTree.""" - tree = serve.run(PrefixTree.bind()) +def test_remove_tenant(tree: PrefixTree) -> None: + """Test removing a tenant from the tree.""" + # 1. Test basic tenant removal + tree._reset() + tree.insert("hello", "tenant_1", 1) + removed: int = tree.remove_tenant("tenant_1") + assert removed == 5 - # 1. Test eviction with LRU ordering - await tree.reset.remote() - current_time = int(time.time() * 1000) - await tree.insert.remote("a", "tenant_1", current_time) - time.sleep(0.001) - current_time = int(time.time() * 1000) - await tree.insert.remote("bb", "tenant_1", current_time) - time.sleep(0.001) - current_time = int(time.time() * 1000) - await tree.insert.remote("ccc", "tenant_1", current_time) + assert "tenant_1" not in tree.tenants + assert "tenant_1" not in tree.tenant_char_count + assert "tenant_1" not in tree.tenant_nodes - tree_rep = await tree.to_dict.remote() - before = tree_rep["tenant_char_count"]["tenant_1"] + # 2. Test removing tenant with multiple nodes + tree._reset() + tree.insert("cat", "tenant_1", 1) + tree.insert("dog", "tenant_1", 2) + removed = tree.remove_tenant("tenant_1") + assert removed == len("cat") + len("dog") - evicted = await tree.evict_tenant_by_lru.remote("tenant_1", 2) + # 3. Test removing non-existent tenant is idempotent (logs warning, returns 0) + tree._reset() + # Should not raise error, just return 0 + removed = tree.remove_tenant("nonexistent_tenant") + assert removed == 0 - tree_rep = await tree.to_dict.remote() - after = tree_rep["tenant_char_count"]["tenant_1"] + # 4. Test tree structure after removing tenant + tree._reset() + tree.insert("hello", "tenant_1", 1) + tree.insert("hello", "tenant_2", 2) - assert evicted == 3 - assert before - after == evicted - assert "tenant_1" in tree_rep["tenants"] + # Remove tenant_1, verify tenant_2 still works + tree.remove_tenant("tenant_1") - # 2. Test eviction of non-existent tenant raises ValueError - await tree.reset.remote() - with pytest.raises(ValueError): - await tree.evict_tenant_by_lru.remote("nonexistent_tenant", 5) + assert "tenant_1" not in tree.tenants + assert "tenant_2" in tree.tenants - # 3. Test eviction of tenant with insufficient characters raises ValueError - await tree.reset.remote() - await tree.insert.remote("xyz", "tenant_2", int(time.time() * 1000)) - with pytest.raises(ValueError): - await tree.evict_tenant_by_lru.remote("tenant_2", 4) + matched_text, tenants = tree.prefix_match("hello") + assert matched_text == "hello" + assert tenants == ["tenant_2"] + + # 5. Test removing the last tenant from a node removes the node + tree._reset() + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) - # 4. Test eviction of all tenant data - await tree.reset.remote() - await tree.insert.remote("xyz", "tenant_2", int(time.time() * 1000)) + # Remove tenant_1 + tree.remove_tenant("tenant_1") - tree_rep = await tree.to_dict.remote() - total_size = tree_rep["tenant_char_count"]["tenant_2"] + root: Node = tree.root + # 'h' node should only have one child now ('t' from hellothere) + assert "h" in root.children + assert "t" in root.children["h"].children + assert len(root.children["h"].children) == 1 - evicted = await tree.evict_tenant_by_lru.remote("tenant_2", total_size) - assert evicted == total_size - tree_rep = await tree.to_dict.remote() - assert "tenant_2" in tree_rep["tenants"] +def test_evict_tenant_by_lru(tree: PrefixTree) -> None: + """Test the evict_tenant_by_lru functionality of PrefixTree.""" - # 5. Test tree structure and LRU heap ordering - await tree.reset.remote() + # 1. Remove exactly min_remove_size characters + tree._reset() + tree.insert("a", "tenant_1", 1) + tree.insert("bb", "tenant_1", 2) + tree.insert("ccc", "tenant_1", 3) + + # Before eviction + char_count_before: int = tree.tenant_char_count["tenant_1"] + assert len(tree.tenant_nodes["tenant_1"]) == 4 + assert tree.tenant_char_count["tenant_1"] == 6 + + # During eviction + min_remove_size: int = 1 + evicted_count: int = tree.evict_tenant_by_lru("tenant_1", min_remove_size) + + # After eviction + char_count_after: int = tree.tenant_char_count["tenant_1"] + assert evicted_count == min_remove_size + assert char_count_before - char_count_after == evicted_count + assert len(tree.tenant_nodes["tenant_1"]) == 3 + assert tree.tenant_char_count["tenant_1"] == 5 + + # 2. Remove more than min_remove_size characters + tree._reset() + tree.insert("a", "tenant_1", 1) + tree.insert("bb", "tenant_1", 2) + tree.insert("ccc", "tenant_1", 3) + + # Before eviction + char_count_before = tree.tenant_char_count["tenant_1"] + assert len(tree.tenant_nodes["tenant_1"]) == 4 + assert tree.tenant_char_count["tenant_1"] == 6 + + # During eviction + min_remove_size = 2 + evicted_count = tree.evict_tenant_by_lru("tenant_1", min_remove_size) + + # After eviction + char_count_after = tree.tenant_char_count["tenant_1"] + assert evicted_count != min_remove_size and evicted_count == 3 + assert char_count_before - char_count_after == evicted_count + assert len(tree.tenant_nodes["tenant_1"]) == 2 + assert tree.tenant_char_count["tenant_1"] == 3 + + # 3. Test eviction of non-existent tenant is idempotent + tree._reset() + # Should not raise error, just return 0 + evicted_count = tree.evict_tenant_by_lru("nonexistent_tenant", 5) + assert evicted_count == 0 + + # 4. Test eviction of tenant with insufficient characters is idempotent + tree._reset() + tree.insert("xyz", "tenant_1", 1) + # Should not raise error, should evict all available characters + evicted_count = tree.evict_tenant_by_lru("tenant_1", 4) + assert evicted_count == 3 # "xyz" has 3 characters + + # 5. Test eviction of all tenant data + tree._reset() + tree.insert("xyz", "tenant_1", 1) + + total_size: int = tree.tenant_char_count["tenant_1"] + + evicted_count = tree.evict_tenant_by_lru("tenant_1", total_size) + assert evicted_count == total_size + + # "tenant_1" should still be in tenants + assert "tenant_1" in tree.tenants + + # 6. Test tree structure and LRU heap ordering + tree._reset() # Insert strings in specified order - await tree.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 - await tree.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 - await tree.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 + tree.insert("helloworld", "tenant_1", 1) # time 1 for tenant_1 + tree.insert("hellothere", "tenant_2", 2) # time 2 for tenant_2 + tree.insert("hellothomas", "tenant_2", 3) # time 3 for tenant_2 - # Get tree representation for testing - tree_rep = await tree.to_dict.remote() - root = tree_rep["root"] + # Access tree directly + root: Node = tree.root # Test tree structure - validate each node # Root node @@ -333,86 +448,84 @@ async def test_evict_tenant_by_lru(): assert "h" in root.children # Hello node - hello_node = root.children["h"] + hello_node: Node = root.children["h"] assert hello_node.text == "hello" assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "w" in hello_node.children assert "t" in hello_node.children # World node - world_node = hello_node.children["w"] + world_node: Node = hello_node.children["w"] assert world_node.text == "world" assert world_node.tenant_last_access_time == {"tenant_1": 1} assert len(world_node.children) == 0 # Th node - th_node = hello_node.children["t"] + th_node: Node = hello_node.children["t"] assert th_node.text == "th" assert th_node.tenant_last_access_time == {"tenant_2": 3} assert "e" in th_node.children assert "o" in th_node.children # Ere node - ere_node = th_node.children["e"] + ere_node: Node = th_node.children["e"] assert ere_node.text == "ere" assert ere_node.tenant_last_access_time == {"tenant_2": 2} assert len(ere_node.children) == 0 # Omas node - omas_node = th_node.children["o"] + omas_node: Node = th_node.children["o"] assert omas_node.text == "omas" assert omas_node.tenant_last_access_time == {"tenant_2": 3} assert len(omas_node.children) == 0 # Test PrefixTree instance variables - assert tree_rep["tenants"] == {"tenant_1", "tenant_2"} + assert tree.tenants == {"tenant_1", "tenant_2"} # Test tenant_char_count - assert tree_rep["tenant_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 - assert tree_rep["tenant_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + assert tree.tenant_char_count["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 + assert tree.tenant_char_count["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 # Test tenant_nodes (check by text) - tenant1_nodes_texts = {node.text for node in tree_rep["tenant_nodes"]["tenant_1"]} + tenant1_nodes_texts: Set[str] = {node.text for node in tree.tenant_nodes["tenant_1"]} assert tenant1_nodes_texts == {"", "hello", "world"} - tenant2_nodes_texts = {node.text for node in tree_rep["tenant_nodes"]["tenant_2"]} + tenant2_nodes_texts: Set[str] = {node.text for node in tree.tenant_nodes["tenant_2"]} assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} # Test tenant_nodes_sorted - validate heap ordering - assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_1"]).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_1"]).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_2"]).node.tenant_last_access_time["tenant_2"] == 2 - assert heapq.heappop(tree_rep["tenant_nodes_sorted"]["tenant_2"]).node.tenant_last_access_time["tenant_2"] == 3 + tenant1_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_1"] + tenant2_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_2"] + + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 + assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 + assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 3 -@pytest.mark.asyncio -async def test_get_smallest_tenant(): +def test_get_smallest_tenant(tree: PrefixTree) -> None: """Test the get_smallest_tenant functionality of PrefixTree.""" - tree = serve.run(PrefixTree.bind()) - # 1. Test with empty tree - await tree.reset.remote() - smallest = await tree.get_smallest_tenant.remote() + tree._reset() + smallest: Optional[str] = tree.get_smallest_tenant() assert smallest is None # 2. Test with multiple tenants of different sizes - await tree.reset.remote() - current_time = int(time.time() * 1000) - await tree.insert.remote("aaaa", "tenant_1", current_time) - await tree.insert.remote("bb", "tenant_2", current_time) - await tree.insert.remote("c", "tenant_3", current_time) + tree._reset() + tree.insert("aaaa", "tenant_1", 1) + tree.insert("bb", "tenant_2", 2) + tree.insert("c", "tenant_3", 3) - smallest = await tree.get_smallest_tenant.remote() + smallest = tree.get_smallest_tenant() assert smallest == "tenant_3" # 3. Test after removing the smallest tenant - await tree.reset.remote() - current_time = int(time.time() * 1000) - await tree.insert.remote("aaaa", "tenant_1", current_time) - await tree.insert.remote("bb", "tenant_2", current_time) - await tree.insert.remote("c", "tenant_3", current_time) - await tree.remove_tenant.remote("tenant_3") - smallest = await tree.get_smallest_tenant.remote() + tree._reset() + tree.insert("aaaa", "tenant_1", 1) + tree.insert("bb", "tenant_2", 2) + tree.insert("c", "tenant_3", 3) + tree.remove_tenant("tenant_3") + smallest = tree.get_smallest_tenant() assert smallest == "tenant_2" From 3e2e393a864a59c967dbc36f6e891c734661ce23 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Tue, 6 May 2025 18:22:22 -0700 Subject: [PATCH 09/15] linting Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 58 ++++---- .../serve/cpu/deployments/test_prefix_tree.py | 130 +++++++++++------- 2 files changed, 111 insertions(+), 77 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index e3fd81cd139cb..eedbcd417311d 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -33,9 +33,11 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: parent: The parent node of this node """ self.text: str = text - self.parent: Optional[Node] = parent # The parent node of this node + self.parent: Optional[Node] = parent # The parent node of this node self.children: Dict[str, Node] = {} # Maps first character to child node - self.tenant_last_access_time: Dict[str, int] = ( + self.tenant_last_access_time: Dict[ + str, int + ] = ( {} ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) @@ -95,14 +97,14 @@ class PrefixTree: - "helloworld" at time 1 by tenant_1 - "hellothere" at time 2 by tenant_2 - "hellothomas" at time 3 by tenant_2 - + root: [] {tenant_1: 1, tenant_2: 3} (h) → [hello] {tenant_1: 1, tenant_2: 3} (w) → [world] {tenant_1: 1} (t) → [th] {tenant_2: 3} (e) → [ere] {tenant_2: 2} (o) → [omas] {tenant_2: 3} - + Legend for each node: - [text] = Node.text - {tenant, timestamp} = Node.tenant_last_access_time @@ -121,17 +123,20 @@ def __init__(self) -> None: self.lock: RLock = RLock() self.root: Node = Node() self.tenants: Set[str] = set() # Set of tenant IDs in the tree - self.tenant_char_count: Dict[str, int] = ( - {} - ) # Tracks total character count per tenant - self.tenant_nodes: Dict[str, Set[Node]] = ( + self.tenant_char_count: Dict[ + str, int + ] = {} # Tracks total character count per tenant + self.tenant_nodes: Dict[ + str, Set[Node] + ] = ( {} ) # Maps tenant ID to set of nodes belonging to that tenant. Used for O(1) lookup of whether a node belongs to a tenant. - self.tenant_nodes_sorted: Dict[str, List[TenantHeapNode]] = ( + self.tenant_nodes_sorted: Dict[ + str, List[TenantHeapNode] + ] = ( {} ) # Maps tenant ID to heap of nodes for LRU eviction. Used for O(log n) insertion and eviction of LRU node. - @staticmethod def _shared_prefix_count(a: str, b: str) -> int: """ @@ -146,11 +151,10 @@ def _shared_prefix_count(a: str, b: str) -> int: """ return len(os.path.commonprefix([a, b])) - def _reset(self) -> None: """ Reset the tree to an empty state. - + Note: This method is intended to be used only in tests. """ with self.lock: @@ -160,7 +164,6 @@ def _reset(self) -> None: self.tenant_nodes = {} self.tenant_nodes_sorted = {} - def _add_tenant(self, tenant: str) -> None: """ Add a new tenant to the tree. @@ -180,7 +183,6 @@ def _add_tenant(self, tenant: str) -> None: self.tenant_nodes[tenant] = set() self.tenant_nodes_sorted[tenant] = [] - def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ Remove a tenant from a single node. @@ -221,7 +223,6 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len - def insert(self, text: str, tenant: str, time_sec: float) -> Node: """ Insert text into tree for a specific tenant. @@ -236,7 +237,7 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: Returns: The node that was inserted or updated - Note: + Note: Loop structure: 1. At the start of each iteration, curr_node is a node we potentially update. e.g. node.tenant_last_access_time[tenant], self.tenant_char_count, @@ -339,7 +340,6 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: return curr_node - def prefix_match( self, text: str, available_tenants: Optional[List[str]] = None ) -> Tuple[str, Optional[List[str]]]: @@ -406,7 +406,6 @@ def prefix_match( return matched_text, matching_tenants - def remove_tenant(self, tenant: str) -> int: """ Remove a tenant and all its nodes from the tree. @@ -433,7 +432,6 @@ def remove_tenant(self, tenant: str) -> int: return total_chars_removed - def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ Evict least recently used nodes for a tenant until minimum size is freed. @@ -467,19 +465,27 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: and self.tenant_nodes_sorted[tenant] ): # Get the minimum access time from the top of the heap - oldest_access_time = self.tenant_nodes_sorted[tenant][0].node.tenant_last_access_time[tenant] - + oldest_access_time = self.tenant_nodes_sorted[tenant][ + 0 + ].node.tenant_last_access_time[tenant] + # Remove all nodes with this same access time - while (self.tenant_nodes_sorted[tenant] and - self.tenant_nodes_sorted[tenant][0].node.tenant_last_access_time[tenant] == oldest_access_time): - heap_node: TenantHeapNode = heapq.heappop(self.tenant_nodes_sorted[tenant]) + while ( + self.tenant_nodes_sorted[tenant] + and self.tenant_nodes_sorted[tenant][ + 0 + ].node.tenant_last_access_time[tenant] + == oldest_access_time + ): + heap_node: TenantHeapNode = heapq.heappop( + self.tenant_nodes_sorted[tenant] + ) total_chars_removed += self._remove_tenant_single_node( tenant, heap_node.node ) return total_chars_removed - def get_smallest_tenant(self) -> Optional[str]: """ Get the tenant with the smallest total character count. @@ -511,4 +517,4 @@ def _to_dict(self) -> Dict[str, Any]: "tenant_char_count": self.tenant_char_count, "tenant_nodes": self.tenant_nodes, "tenant_nodes_sorted": self.tenant_nodes_sorted, - } \ No newline at end of file + } diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index abaea937871b4..f8409f531f64e 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -1,20 +1,24 @@ import pytest -import time import ray from ray import serve import heapq -from typing import Set, List, Dict, Optional, Generator, Any +from typing import Set, List, Dict, Optional, Generator from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( - PrefixTree, PrefixTreeDeployment, Node, TenantHeapNode + PrefixTree, + PrefixTreeDeployment, + Node, + TenantHeapNode, ) + # Fixtures @pytest.fixture def tree() -> PrefixTree: """Create a fresh PrefixTree instance for each test.""" return PrefixTree() + @pytest.fixture(scope="module", autouse=True) def serve_instance() -> Generator[None, None, None]: # Start Ray and Serve once per test module @@ -24,84 +28,100 @@ def serve_instance() -> Generator[None, None, None]: serve.shutdown() ray.shutdown() + @pytest.fixture(scope="module") def tree_deployment(): """Create a fresh PrefixTreeDeployment instance for each test.""" tree = serve.run(PrefixTreeDeployment.bind()) return tree + # PrefixTreeDeployment tests @pytest.mark.asyncio async def test_tree_deployment(tree_deployment) -> None: """Test the PrefixTreeDeployment.""" # 6. Test tree structure and LRU heap ordering await tree_deployment._reset.remote() - + # Insert strings in specified order - await tree_deployment.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 - await tree_deployment.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 - await tree_deployment.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 - + await tree_deployment.insert.remote( + "helloworld", "tenant_1", 1 + ) # time 1 for tenant_1 + await tree_deployment.insert.remote( + "hellothere", "tenant_2", 2 + ) # time 2 for tenant_2 + await tree_deployment.insert.remote( + "hellothomas", "tenant_2", 3 + ) # time 3 for tenant_2 + # Access tree directly tree_rep: Dict = await tree_deployment._to_dict.remote() root: Node = tree_rep["root"] - + # Test tree structure - validate each node # Root node assert root.text == "" assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "h" in root.children - + # Hello node hello_node: Node = root.children["h"] assert hello_node.text == "hello" assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "w" in hello_node.children assert "t" in hello_node.children - + # World node world_node: Node = hello_node.children["w"] assert world_node.text == "world" assert world_node.tenant_last_access_time == {"tenant_1": 1} assert len(world_node.children) == 0 - + # Th node th_node: Node = hello_node.children["t"] assert th_node.text == "th" assert th_node.tenant_last_access_time == {"tenant_2": 3} assert "e" in th_node.children assert "o" in th_node.children - + # Ere node ere_node: Node = th_node.children["e"] assert ere_node.text == "ere" assert ere_node.tenant_last_access_time == {"tenant_2": 2} assert len(ere_node.children) == 0 - + # Omas node omas_node: Node = th_node.children["o"] assert omas_node.text == "omas" assert omas_node.tenant_last_access_time == {"tenant_2": 3} assert len(omas_node.children) == 0 - + # Test PrefixTree instance variables assert tree_rep["tenants"] == {"tenant_1", "tenant_2"} - + # Test tenant_char_count - assert tree_rep["tenant_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 - assert tree_rep["tenant_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - + assert ( + tree_rep["tenant_char_count"]["tenant_1"] == 10 + ) # root(0) + hello(5) + world(5) = 10 + assert ( + tree_rep["tenant_char_count"]["tenant_2"] == 14 + ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + # Test tenant_nodes (check by text) - tenant1_nodes_texts: Set[str] = {node.text for node in tree_rep["tenant_nodes"]["tenant_1"]} + tenant1_nodes_texts: Set[str] = { + node.text for node in tree_rep["tenant_nodes"]["tenant_1"] + } assert tenant1_nodes_texts == {"", "hello", "world"} - - tenant2_nodes_texts: Set[str] = {node.text for node in tree_rep["tenant_nodes"]["tenant_2"]} + + tenant2_nodes_texts: Set[str] = { + node.text for node in tree_rep["tenant_nodes"]["tenant_2"] + } assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - + # Test tenant_nodes_sorted - validate heap ordering tenant1_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_1"] tenant2_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_2"] - + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 @@ -160,37 +180,37 @@ def test_insert(tree: PrefixTree) -> None: assert h_node.text == "hello" assert h_node.children.get("w").text == "world" assert h_node.children.get("t").text == "there" - + # 4. Test that inserting a longer prompt with shared prefix doesn't create empty text nodes tree._reset() tree.insert("hello", "tenant_1", 1) tree.insert("helloworld", "tenant_2", 2) - + root = tree.root - + # Check that only the root has empty text by directly traversing the tree # Starting from root, collect all nodes with empty text empty_text_nodes: List[Node] = [] nodes_to_check: List[Node] = [root] - + while nodes_to_check: node: Node = nodes_to_check.pop() if node.text == "": empty_text_nodes.append(node) # Add all children to check nodes_to_check.extend(node.children.values()) - + # There should be exactly one empty text node (the root) assert len(empty_text_nodes) == 1 assert root in empty_text_nodes - + # Verify tree structure h_node = root.children.get("h") assert h_node is not None assert h_node.text == "hello" assert "tenant_1" in h_node.tenant_last_access_time assert "tenant_2" in h_node.tenant_last_access_time - + # Verify "world" node belongs only to tenant 2 world_node: Optional[Node] = h_node.children.get("w") assert world_node is not None @@ -366,7 +386,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree.insert("a", "tenant_1", 1) tree.insert("bb", "tenant_1", 2) tree.insert("ccc", "tenant_1", 3) - + # Before eviction char_count_before: int = tree.tenant_char_count["tenant_1"] assert len(tree.tenant_nodes["tenant_1"]) == 4 @@ -388,7 +408,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree.insert("a", "tenant_1", 1) tree.insert("bb", "tenant_1", 2) tree.insert("ccc", "tenant_1", 3) - + # Before eviction char_count_before = tree.tenant_char_count["tenant_1"] assert len(tree.tenant_nodes["tenant_1"]) == 4 @@ -432,71 +452,79 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # 6. Test tree structure and LRU heap ordering tree._reset() - + # Insert strings in specified order tree.insert("helloworld", "tenant_1", 1) # time 1 for tenant_1 tree.insert("hellothere", "tenant_2", 2) # time 2 for tenant_2 tree.insert("hellothomas", "tenant_2", 3) # time 3 for tenant_2 - + # Access tree directly root: Node = tree.root - + # Test tree structure - validate each node # Root node assert root.text == "" assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "h" in root.children - + # Hello node hello_node: Node = root.children["h"] assert hello_node.text == "hello" assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "w" in hello_node.children assert "t" in hello_node.children - + # World node world_node: Node = hello_node.children["w"] assert world_node.text == "world" assert world_node.tenant_last_access_time == {"tenant_1": 1} assert len(world_node.children) == 0 - + # Th node th_node: Node = hello_node.children["t"] assert th_node.text == "th" assert th_node.tenant_last_access_time == {"tenant_2": 3} assert "e" in th_node.children assert "o" in th_node.children - + # Ere node ere_node: Node = th_node.children["e"] assert ere_node.text == "ere" assert ere_node.tenant_last_access_time == {"tenant_2": 2} assert len(ere_node.children) == 0 - + # Omas node omas_node: Node = th_node.children["o"] assert omas_node.text == "omas" assert omas_node.tenant_last_access_time == {"tenant_2": 3} assert len(omas_node.children) == 0 - + # Test PrefixTree instance variables assert tree.tenants == {"tenant_1", "tenant_2"} - + # Test tenant_char_count - assert tree.tenant_char_count["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 - assert tree.tenant_char_count["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - + assert ( + tree.tenant_char_count["tenant_1"] == 10 + ) # root(0) + hello(5) + world(5) = 10 + assert ( + tree.tenant_char_count["tenant_2"] == 14 + ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + # Test tenant_nodes (check by text) - tenant1_nodes_texts: Set[str] = {node.text for node in tree.tenant_nodes["tenant_1"]} + tenant1_nodes_texts: Set[str] = { + node.text for node in tree.tenant_nodes["tenant_1"] + } assert tenant1_nodes_texts == {"", "hello", "world"} - - tenant2_nodes_texts: Set[str] = {node.text for node in tree.tenant_nodes["tenant_2"]} + + tenant2_nodes_texts: Set[str] = { + node.text for node in tree.tenant_nodes["tenant_2"] + } assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - + # Test tenant_nodes_sorted - validate heap ordering tenant1_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_1"] tenant2_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_2"] - + assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 From 73743692f8ede7993d9afcae8af971a2644be96c Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Wed, 7 May 2025 13:04:12 -0700 Subject: [PATCH 10/15] remove unnecessary instance variables Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 264 +++++++------- .../serve/cpu/deployments/test_prefix_tree.py | 331 +++++++++--------- 2 files changed, 298 insertions(+), 297 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index eedbcd417311d..544686164410f 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -16,9 +16,9 @@ class Node: Node in a prefix tree that represents a segment of text and can belong to multiple tenants. Each node also tracks the last access time for each tenant. Simple example of root node connected to two children Nodes: - root = Node(text="", parent=None, children={"f": fooNode, "b": barNode}, tenant_last_access_time={"tenant_1": 2}) - fooNode = Node(text="foo", parent=root, children={}, tenant_last_access_time={"tenant_1": 1}) - barNode = Node(text="bar", parent=root, children={}, tenant_last_access_time={"tenant_1": 2}) + root = Node(text="", parent=None, edge_label_to_child={"f": fooNode, "b": barNode}, tenant_to_last_access_time={"tenant_1": 2}) + fooNode = Node(text="foo", parent=root, edge_label_to_child={}, tenant_to_last_access_time={"tenant_1": 1}) + barNode = Node(text="bar", parent=root, edge_label_to_child={}, tenant_to_last_access_time={"tenant_1": 2}) In the above example, "foo" was inserted at time 1, and "bar" was inserted at time 2. It follows that root was last accessed at time 2. @@ -34,17 +34,13 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: """ self.text: str = text self.parent: Optional[Node] = parent # The parent node of this node - self.children: Dict[str, Node] = {} # Maps first character to child node - self.tenant_last_access_time: Dict[ + self.edge_label_to_child: Dict[str, Node] = {} # Maps first character to child node + self.tenant_to_last_access_time: Dict[ str, int ] = ( {} ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) - def __repr__(self) -> str: - return f"Node(text='{self.text}', children={list(self.children.keys())}, tenants={list(self.tenant_last_access_time.keys())})" - - class TenantHeapNode: """ Wrapper class for storing nodes in a min-heap, ordered by tenant access time. @@ -57,7 +53,7 @@ def __init__(self, node: Node, tenant_ordering_key: str) -> None: Args: node: The prefix tree node this heap node refers to - tenant: The tenant ID this heap node is associated with + tenant_ordering_key: The tenant this heap uses to order nodes """ self.node = node self.tenant_ordering_key = tenant_ordering_key @@ -73,13 +69,10 @@ def __lt__(self, other: TenantHeapNode) -> bool: True if this node's tenant access time is earlier than the other's """ return ( - self.node.tenant_last_access_time[self.tenant_ordering_key] - < other.node.tenant_last_access_time[other.tenant_ordering_key] + self.node.tenant_to_last_access_time[self.tenant_ordering_key] + < other.node.tenant_to_last_access_time[other.tenant_ordering_key] ) - def __repr__(self) -> str: - return f"TenantHeapNode(node={self.node}, tenant_ordering_key={self.tenant_ordering_key})" - class PrefixTree: """ @@ -107,35 +100,24 @@ class PrefixTree: Legend for each node: - [text] = Node.text - - {tenant, timestamp} = Node.tenant_last_access_time + - {tenant, timestamp} = Node.tenant_to_last_access_time - (x) = edge label (first character used as key for parent's children) PrefixTree instance variables: - self.tenants = {"tenant_1", "tenant_2"} - self.tenant_char_count = {"tenant_1": 10, "tenant_2": 14} - self.tenant_nodes = {"tenant_1": {root, Node("hello"), Node("world")}, "tenant_2": {root, Node("hello"), Node("th"), Node("ere"), Node("omas")}} - self.tenant_nodes_sorted = {"tenant_1": [root, Node("hello"), Node("world")], "tenant_2": [Node("ere"), root, Node("hello"), Node("th"), Node("omas")]} - # Note: self.tenant_nodes_sorted is maintained as a min-heap, so the first element is guaranteed to be the least recently used node for that tenant, but the rest of the heap is not guaranteed to be sorted. + self.tenant_to_char_count = {"tenant_1": 10, "tenant_2": 14} + self.tenant_to_nodes = {"tenant_1": {root, Node("hello"), Node("world")}, "tenant_2": {root, Node("hello"), Node("th"), Node("ere"), Node("omas")}} """ def __init__(self) -> None: """Initialize an empty prefix tree.""" self.lock: RLock = RLock() self.root: Node = Node() - self.tenants: Set[str] = set() # Set of tenant IDs in the tree - self.tenant_char_count: Dict[ + self.tenant_to_char_count: Dict[ str, int - ] = {} # Tracks total character count per tenant - self.tenant_nodes: Dict[ + ] = {} # Tracks total character count per tenant. Used by the client to determine which tenant to evict, and by how much. + self.tenant_to_nodes: Dict[ str, Set[Node] - ] = ( - {} - ) # Maps tenant ID to set of nodes belonging to that tenant. Used for O(1) lookup of whether a node belongs to a tenant. - self.tenant_nodes_sorted: Dict[ - str, List[TenantHeapNode] - ] = ( - {} - ) # Maps tenant ID to heap of nodes for LRU eviction. Used for O(log n) insertion and eviction of LRU node. + ] = {} # Maps tenant to set of nodes. Used for O(1) testing if a node belongs to a tenant. The keys are the active tenants in the tree. @staticmethod def _shared_prefix_count(a: str, b: str) -> int: @@ -159,10 +141,8 @@ def _reset(self) -> None: """ with self.lock: self.root = Node() - self.tenants = set() - self.tenant_char_count = {} - self.tenant_nodes = {} - self.tenant_nodes_sorted = {} + self.tenant_to_char_count = {} + self.tenant_to_nodes = {} def _add_tenant(self, tenant: str) -> None: """ @@ -171,55 +151,62 @@ def _add_tenant(self, tenant: str) -> None: If the tenant already exists, this is a no-op with a warning log. Args: - tenant: Tenant ID to add + tenant: Tenant to add """ with self.lock: - if tenant in self.tenants: + if tenant in self.tenant_to_nodes: logger.warning(f"Tenant '{tenant}' already exists. No action taken.") return - self.tenants.add(tenant) - self.tenant_char_count[tenant] = 0 - self.tenant_nodes[tenant] = set() - self.tenant_nodes_sorted[tenant] = [] + self.tenant_to_char_count[tenant] = 0 + self.tenant_to_nodes[tenant] = set() def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ Remove a tenant from a single node. + This function expects valid input where: + - tenant exists in self.tenant_to_nodes + - tenant exists in node.tenant_to_last_access_time + - node exists in self.tenant_to_nodes[tenant] + + These preconditions are guaranteed to be satisfied if the user is using the public methods of this class. + They may be violated if the user manipulates the internal state of the tree directly. + Args: - tenant: Tenant ID to remove + tenant: Tenant to remove node: Node to remove tenant from + Does: + Decrements self.tenant_to_char_count[tenant] by the length of the node's text. + Removes the tenant from node.tenant_to_last_access_time. + Removes the node from self.tenant_to_nodes[tenant]. + Returns: - Number of characters removed + Number of characters removed (0 if preconditions not met) """ with self.lock: - if tenant not in self.tenants: + if tenant not in self.tenant_to_nodes: logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") return 0 - - if ( - node not in self.tenant_nodes[tenant] - or tenant not in node.tenant_last_access_time - ): - logger.warning( - f"Cannot remove node '{node.text}' from tenant '{tenant}': " - f"tenant does not have this node. No action taken." - ) + if tenant not in node.tenant_to_last_access_time: + logger.warning(f"Tenant '{tenant}' does not have node '{node.text}'. No action taken.") + return 0 + if node not in self.tenant_to_nodes[tenant]: + logger.warning(f"Node '{node.text}' does not belong to tenant '{tenant}'. No action taken.") return 0 removed_chars_len: int = len(node.text) - self.tenant_char_count[tenant] -= removed_chars_len - self.tenant_nodes[tenant].remove(node) - node.tenant_last_access_time.pop(tenant, None) + self.tenant_to_char_count[tenant] -= removed_chars_len + self.tenant_to_nodes[tenant].remove(node) + node.tenant_to_last_access_time.pop(tenant, None) # Clean up empty nodes - if not node.tenant_last_access_time and node.parent: + if not node.tenant_to_last_access_time and node.parent: if ( - node.text and node.text[0] in node.parent.children + node.text and node.text[0] in node.parent.edge_label_to_child ): # Defensive check - node.parent.children.pop(node.text[0], None) + node.parent.edge_label_to_child.pop(node.text[0], None) return removed_chars_len @@ -231,17 +218,16 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: Args: text: Text to insert - tenant: Tenant ID + tenant: Tenant time_sec: Current timestamp in seconds Returns: The node that was inserted or updated - Note: - Loop structure: + Loop structure: 1. At the start of each iteration, curr_node is a node we potentially update. - e.g. node.tenant_last_access_time[tenant], self.tenant_char_count, - self.tenant_nodes, self.tenant_nodes_sorted + e.g. Update node.tenant_to_last_access_time[tenant], self.tenant_to_char_count, + self.tenant_to_nodes 2. Each iteration then either: a. Breaks (if we've processed the entire string). b. Processes the next segment of text by: @@ -249,44 +235,37 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: 2. Then, match the current text with the child's text: a. If they share a prefix (partial match), split the node and traverse into the new parent. b. If they fully match, traverse into the child node. - 3. The self.tenant_nodes_sorted heap is reheapified at each node visit to maintain LRU order. - - This structure allows us to efficiently insert text while maintaining shared prefixes - and tracking tenant access times for the LRU eviction mechanism. """ with self.lock: - if tenant not in self.tenants: + if tenant not in self.tenant_to_nodes: self._add_tenant(tenant) curr_node: Node = self.root i: int = 0 while i <= len(text): - # Invariant: assume curr_node has not been visited by tenant yet - # Update tenant info for current node - if tenant not in curr_node.tenant_last_access_time: - self.tenant_char_count[tenant] += len(curr_node.text) - self.tenant_nodes[tenant].add(curr_node) - self.tenant_nodes_sorted[tenant].append( - TenantHeapNode(curr_node, tenant) - ) + # Invariant at beginning of each iteration: assume curr_node has not been visited by tenant yet. + # Update tenant info for current node. + if tenant not in curr_node.tenant_to_last_access_time: + self.tenant_to_char_count[tenant] += len(curr_node.text) + self.tenant_to_nodes[tenant].add(curr_node) + + curr_node.tenant_to_last_access_time[tenant] = time_sec - curr_node.tenant_last_access_time[tenant] = time_sec - heapq.heapify(self.tenant_nodes_sorted[tenant]) if i == len(text): break first_char: str = text[i] curr_text: str = text[i:] - if first_char not in curr_node.children: + if first_char not in curr_node.edge_label_to_child: # No match, create new node. Don't update new node as "visited" by tenant yet; it will be done in the code below. - # e.g. curr_node.children = {}, curr_text = "hello" -> curr_node.children = {"h": Node("hello")} + # e.g. curr_node.edge_label_to_child = {}, curr_text = "hello" -> curr_node.edge_label_to_child = {"h": Node("hello")} new_node: Node = Node(text=curr_text, parent=curr_node) - curr_node.children[first_char] = new_node + curr_node.edge_label_to_child[first_char] = new_node # Match found, check if we need to split - matched_node: Node = curr_node.children[first_char] + matched_node: Node = curr_node.edge_label_to_child[first_char] shared_count: int = self._shared_prefix_count( matched_node.text, curr_text ) @@ -295,16 +274,16 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: # Partial match, split node at matched point # Example: ## Before update: - ### curr_node.children = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 + ### curr_node.edge_label_to_child = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 ### matched_node = Node("helloworld") ## During update: - ### Increment tenant_char_count[tenant] by shared_count if matched_node has not seen this tenant before + ### Increment tenant_to_char_count[tenant] by shared_count if matched_node has not seen this tenant before ## After update: - ### curr_node.children = {"h": Node("hello", children = {"w": Node("world")})} + ### curr_node.edge_label_to_child = {"h": Node("hello", edge_label_to_child = {"w": Node("world")})} ### parent_node = Node("hello"), matched_node = Node("world") - ### Update tenant_last_access_time for parent_node, NOT matched_node + ### Update tenant_to_last_access_time for parent_node, NOT matched_node ### (new) curr_text = "there", (new) curr_node = parent_node ### Continue adding "there" to tree in next iteration @@ -313,22 +292,19 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: # Create new intermediate node new_parent: Node = Node(text=matched_text, parent=curr_node) - new_parent.tenant_last_access_time = ( - matched_node.tenant_last_access_time.copy() + new_parent.tenant_to_last_access_time = ( + matched_node.tenant_to_last_access_time.copy() ) - for existing_tenant in new_parent.tenant_last_access_time: - self.tenant_nodes[existing_tenant].add(new_parent) - self.tenant_nodes_sorted[existing_tenant].append( - TenantHeapNode(new_parent, existing_tenant) - ) + for existing_tenant in new_parent.tenant_to_last_access_time: + self.tenant_to_nodes[existing_tenant].add(new_parent) # Update existing matched node matched_node.text = remaining_text matched_node.parent = new_parent # Connect nodes - new_parent.children[remaining_text[0]] = matched_node - curr_node.children[first_char] = new_parent + new_parent.edge_label_to_child[remaining_text[0]] = matched_node + curr_node.edge_label_to_child[first_char] = new_parent # Continue traversal curr_node = new_parent @@ -351,17 +327,17 @@ def prefix_match( available_tenants: List of tenants to match against (or None for all) Returns: - Tuple of (matched_text, matched_tenant_ids) + Tuple of (matched_text, matched_tenants) """ if available_tenants: # Filter available_tenants to only include those in the tree available_tenants = [ - tenant for tenant in available_tenants if tenant in self.tenants + tenant for tenant in available_tenants if tenant in self.tenant_to_nodes ] if not available_tenants: return "", None else: - available_tenants = list(self.tenants) + available_tenants = list(self.tenant_to_nodes.keys()) with self.lock: curr_node: Node = self.root @@ -372,12 +348,12 @@ def prefix_match( first_char: str = text[i] curr_text: str = text[i:] - if first_char in curr_node.children: - matched_node: Node = curr_node.children[first_char] + if first_char in curr_node.edge_label_to_child: + matched_node: Node = curr_node.edge_label_to_child[first_char] # Check if any available tenants match this node if not any( - tenant in matched_node.tenant_last_access_time + tenant in matched_node.tenant_to_last_access_time for tenant in available_tenants ): break @@ -396,39 +372,37 @@ def prefix_match( break # Find tenants in current node that match available tenants - matching_tenants = [ + matched_tenants = [ tenant for tenant in available_tenants - if tenant in curr_node.tenant_last_access_time + if tenant in curr_node.tenant_to_last_access_time ] or None matched_text: str = text[:i] - return matched_text, matching_tenants + return matched_text, matched_tenants def remove_tenant(self, tenant: str) -> int: """ Remove a tenant and all its nodes from the tree. Args: - tenant: Tenant ID to remove + tenant: Tenant to remove Returns: Number of characters removed """ with self.lock: - if tenant not in self.tenants: + if tenant not in self.tenant_to_nodes: logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") return 0 total_chars_removed: int = 0 - for node in self.tenant_nodes[tenant].copy(): + for node in self.tenant_to_nodes[tenant].copy(): total_chars_removed += self._remove_tenant_single_node(tenant, node) - self.tenants.remove(tenant) - self.tenant_nodes.pop(tenant, None) - self.tenant_char_count.pop(tenant, None) - self.tenant_nodes_sorted.pop(tenant, None) + self.tenant_to_nodes.pop(tenant, None) + self.tenant_to_char_count.pop(tenant, None) return total_chars_removed @@ -442,46 +416,56 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: Returns: Actual number of characters removed + + Note: + - All nodes with the same oldest access time are removed together to maintain tree integrity, even if only removing a subset of them satisfies the min_remove_size. + - This behavior is expected in the case when an input was split into multiple nodes by a different tenant (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 2)). + because there is no reason to only remove "world" from tenant 1. So we remove the "chain" of "hello" and "world" from tenant 1. + - However, if two inputs happen to be inserted at the same time (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 1)), + then both "chains" will be removed by our method. This may not reflect the actual KV cache eviction policy. + - For more predictable eviction, use unique timestamps for each insertion. """ with self.lock: - if tenant not in self.tenant_nodes or not self.tenant_nodes[tenant]: + if tenant not in self.tenant_to_nodes or not self.tenant_to_nodes[tenant]: logger.warning( f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes. No action taken." ) return 0 - if self.tenant_char_count[tenant] < min_remove_size: + if self.tenant_to_char_count[tenant] < min_remove_size: logger.warning( f"Cannot evict {min_remove_size} characters from tenant '{tenant}', which has only " - f"{self.tenant_char_count[tenant]} characters. Will remove all available characters." + f"{self.tenant_to_char_count[tenant]} characters. Will remove all available characters." ) - min_remove_size = self.tenant_char_count[tenant] + min_remove_size = self.tenant_to_char_count[tenant] total_chars_removed: int = 0 - # Directly use the tenant's priority queue + # Create a min-heap of nodes ordered by access time + # Each entry is a tuple of (access_time, node) so heapq sorts by access_time first + nodes_by_access_time = [] + for node in self.tenant_to_nodes[tenant]: + access_time = node.tenant_to_last_access_time[tenant] + nodes_by_access_time.append((access_time, node)) + heapq.heapify(nodes_by_access_time) + + # Remove nodes until we've freed enough characters while ( total_chars_removed < min_remove_size - and self.tenant_nodes_sorted[tenant] + and nodes_by_access_time ): - # Get the minimum access time from the top of the heap - oldest_access_time = self.tenant_nodes_sorted[tenant][ - 0 - ].node.tenant_last_access_time[tenant] + # Get the oldest (minimum) access time from the top of the heap + oldest_access_time = nodes_by_access_time[0][0] - # Remove all nodes with this same access time + # Remove ALL nodes with this same access time to maintain tree consistency + # (partial removals could break prefix relationships) while ( - self.tenant_nodes_sorted[tenant] - and self.tenant_nodes_sorted[tenant][ - 0 - ].node.tenant_last_access_time[tenant] - == oldest_access_time + nodes_by_access_time + and nodes_by_access_time[0][0] == oldest_access_time ): - heap_node: TenantHeapNode = heapq.heappop( - self.tenant_nodes_sorted[tenant] - ) + _, node_to_remove = heapq.heappop(nodes_by_access_time) total_chars_removed += self._remove_tenant_single_node( - tenant, heap_node.node + tenant, node_to_remove ) return total_chars_removed @@ -491,13 +475,13 @@ def get_smallest_tenant(self) -> Optional[str]: Get the tenant with the smallest total character count. Returns: - Tenant ID with smallest character count, or None if no tenants + Tenant with smallest character count, or None if no tenants """ with self.lock: - if not self.tenant_char_count: + if not self.tenant_to_char_count: return None - return min(self.tenant_char_count, key=self.tenant_char_count.get, default=None) + return min(self.tenant_to_char_count, key=self.tenant_to_char_count.get, default=None) @serve.deployment(name="TreeDeployment") @@ -513,8 +497,6 @@ def _to_dict(self) -> Dict[str, Any]: """ return { "root": self.root, - "tenants": self.tenants, - "tenant_char_count": self.tenant_char_count, - "tenant_nodes": self.tenant_nodes, - "tenant_nodes_sorted": self.tenant_nodes_sorted, + "tenant_to_char_count": self.tenant_to_char_count, + "tenant_to_nodes": self.tenant_to_nodes, } diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index f8409f531f64e..235f4c3caadb5 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -8,7 +8,6 @@ PrefixTree, PrefixTreeDeployment, Node, - TenantHeapNode, ) @@ -61,72 +60,85 @@ async def test_tree_deployment(tree_deployment) -> None: # Test tree structure - validate each node # Root node assert root.text == "" - assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "h" in root.children + assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "h" in root.edge_label_to_child # Hello node - hello_node: Node = root.children["h"] + hello_node: Node = root.edge_label_to_child["h"] assert hello_node.text == "hello" - assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "w" in hello_node.children - assert "t" in hello_node.children + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "w" in hello_node.edge_label_to_child + assert "t" in hello_node.edge_label_to_child # World node - world_node: Node = hello_node.children["w"] + world_node: Node = hello_node.edge_label_to_child["w"] assert world_node.text == "world" - assert world_node.tenant_last_access_time == {"tenant_1": 1} - assert len(world_node.children) == 0 + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert len(world_node.edge_label_to_child) == 0 # Th node - th_node: Node = hello_node.children["t"] + th_node: Node = hello_node.edge_label_to_child["t"] assert th_node.text == "th" - assert th_node.tenant_last_access_time == {"tenant_2": 3} - assert "e" in th_node.children - assert "o" in th_node.children + assert th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert "e" in th_node.edge_label_to_child + assert "o" in th_node.edge_label_to_child # Ere node - ere_node: Node = th_node.children["e"] + ere_node: Node = th_node.edge_label_to_child["e"] assert ere_node.text == "ere" - assert ere_node.tenant_last_access_time == {"tenant_2": 2} - assert len(ere_node.children) == 0 + assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert len(ere_node.edge_label_to_child) == 0 # Omas node - omas_node: Node = th_node.children["o"] + omas_node: Node = th_node.edge_label_to_child["o"] assert omas_node.text == "omas" - assert omas_node.tenant_last_access_time == {"tenant_2": 3} - assert len(omas_node.children) == 0 + assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert len(omas_node.edge_label_to_child) == 0 # Test PrefixTree instance variables - assert tree_rep["tenants"] == {"tenant_1", "tenant_2"} + # Using tenant_to_nodes instead of tenants + assert set(tree_rep["tenant_to_nodes"].keys()) == {"tenant_1", "tenant_2"} - # Test tenant_char_count - assert ( - tree_rep["tenant_char_count"]["tenant_1"] == 10 - ) # root(0) + hello(5) + world(5) = 10 - assert ( - tree_rep["tenant_char_count"]["tenant_2"] == 14 - ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - - # Test tenant_nodes (check by text) + # Test tenant_to_nodes (check by text) tenant1_nodes_texts: Set[str] = { - node.text for node in tree_rep["tenant_nodes"]["tenant_1"] + node.text for node in tree_rep["tenant_to_nodes"]["tenant_1"] } assert tenant1_nodes_texts == {"", "hello", "world"} tenant2_nodes_texts: Set[str] = { - node.text for node in tree_rep["tenant_nodes"]["tenant_2"] + node.text for node in tree_rep["tenant_to_nodes"]["tenant_2"] } assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - # Test tenant_nodes_sorted - validate heap ordering - tenant1_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_1"] - tenant2_heap: List[TenantHeapNode] = tree_rep["tenant_nodes_sorted"]["tenant_2"] - - assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 - assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 3 + # Test tenant_to_char_count + ## Before evictions + assert ( + tree_rep["tenant_to_char_count"]["tenant_1"] == 10 + ) # root(0) + hello(5) + world(5) = 10 + assert ( + tree_rep["tenant_to_char_count"]["tenant_2"] == 14 + ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + ## After evicting tenant_1 with min_remove_size=1 + # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp + evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_1", 1) + assert evicted_count == 10 # All 10 chars removed, not just 1 + tree_rep: Dict = await tree_deployment._to_dict.remote() + assert tree_rep["tenant_to_char_count"]["tenant_1"] == 0 + + ## After evicting tenant_2 with min_remove_size=1 + # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) + evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) + assert evicted_count == 3 # All 3 chars from "ere" removed + tree_rep: Dict = await tree_deployment._to_dict.remote() + assert tree_rep["tenant_to_char_count"]["tenant_2"] == 11 # 14 - 3 = 11 + + ## After evicting tenant_2 again with min_remove_size=1 + # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 + evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) + assert evicted_count == 11 # All 11 remaining chars removed + tree_rep: Dict = await tree_deployment._to_dict.remote() + assert tree_rep["tenant_to_char_count"]["tenant_2"] == 0 # PrefixTree tests def test__add_tenant(tree: PrefixTree) -> None: @@ -134,9 +146,9 @@ def test__add_tenant(tree: PrefixTree) -> None: # 1. Test basic tenant addition tree._reset() tree._add_tenant("tenant_1") - assert "tenant_1" in tree.tenants - assert tree.tenant_char_count["tenant_1"] == 0 - assert tree.tenant_nodes["tenant_1"] == set() + assert "tenant_1" in tree.tenant_to_nodes + assert tree.tenant_to_char_count["tenant_1"] == 0 + assert tree.tenant_to_nodes["tenant_1"] == set() # 2. Test adding duplicate tenant logs warning but doesn't raise error tree._reset() @@ -144,7 +156,7 @@ def test__add_tenant(tree: PrefixTree) -> None: # This should be a no-op tree._add_tenant("tenant_1") # Verify the tenant still exists - assert "tenant_1" in tree.tenants + assert "tenant_1" in tree.tenant_to_nodes def test_insert(tree: PrefixTree) -> None: @@ -153,12 +165,12 @@ def test_insert(tree: PrefixTree) -> None: tree._reset() # No need to call add_tenant first - insert will do it automatically tree.insert("hello", "tenant_1", 1) - matched_text, tenants = tree.prefix_match("hello") + matched_text, matched_tenants = tree.prefix_match("hello") assert matched_text == "hello" - assert tenants == ["tenant_1"] + assert matched_tenants == ["tenant_1"] - assert tree.tenant_char_count["tenant_1"] == 5 - assert len(tree.tenant_nodes["tenant_1"]) == 2 + assert tree.tenant_to_char_count["tenant_1"] == 5 + assert len(tree.tenant_to_nodes["tenant_1"]) == 2 # 2. Test duplicate insertion doesn't double count tree._reset() @@ -166,8 +178,8 @@ def test_insert(tree: PrefixTree) -> None: tree.insert("foo", "tenant_1", 1) # duplicate tree.insert("bar", "tenant_2", 2) - assert tree.tenant_char_count["tenant_1"] == 3 - assert tree.tenant_char_count["tenant_2"] == 3 + assert tree.tenant_to_char_count["tenant_1"] == 3 + assert tree.tenant_to_char_count["tenant_2"] == 3 # 3. Test node splitting on partial match tree._reset() @@ -175,11 +187,11 @@ def test_insert(tree: PrefixTree) -> None: tree.insert("hellothere", "tenant_2", 2) root: Node = tree.root - h_node: Optional[Node] = root.children.get("h") + h_node: Optional[Node] = root.edge_label_to_child.get("h") assert h_node is not None assert h_node.text == "hello" - assert h_node.children.get("w").text == "world" - assert h_node.children.get("t").text == "there" + assert h_node.edge_label_to_child.get("w").text == "world" + assert h_node.edge_label_to_child.get("t").text == "there" # 4. Test that inserting a longer prompt with shared prefix doesn't create empty text nodes tree._reset() @@ -198,96 +210,94 @@ def test_insert(tree: PrefixTree) -> None: if node.text == "": empty_text_nodes.append(node) # Add all children to check - nodes_to_check.extend(node.children.values()) + nodes_to_check.extend(node.edge_label_to_child.values()) # There should be exactly one empty text node (the root) assert len(empty_text_nodes) == 1 assert root in empty_text_nodes # Verify tree structure - h_node = root.children.get("h") + h_node = root.edge_label_to_child.get("h") assert h_node is not None assert h_node.text == "hello" - assert "tenant_1" in h_node.tenant_last_access_time - assert "tenant_2" in h_node.tenant_last_access_time + assert "tenant_1" in h_node.tenant_to_last_access_time + assert "tenant_2" in h_node.tenant_to_last_access_time # Verify "world" node belongs only to tenant 2 - world_node: Optional[Node] = h_node.children.get("w") + world_node: Optional[Node] = h_node.edge_label_to_child.get("w") assert world_node is not None assert world_node.text == "world" - assert "tenant_2" in world_node.tenant_last_access_time - assert "tenant_1" not in world_node.tenant_last_access_time + assert "tenant_2" in world_node.tenant_to_last_access_time + assert "tenant_1" not in world_node.tenant_to_last_access_time # Verify the only child of h_node is "w" - assert len(h_node.children) == 1 + assert len(h_node.edge_label_to_child) == 1 def test_prefix_match(tree: PrefixTree) -> None: """Test the prefix_match functionality of PrefixTree.""" # 1. Test no match tree._reset() - matched_text, tenants = tree.prefix_match("hello") + matched_text, matched_tenants = tree.prefix_match("hello") assert matched_text == "" - assert tenants is None + assert matched_tenants is None # 2. Test match with non-existing prefix returns empty string and all tenants tree._reset() tree.insert("hello", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) - matched_text, tenants = tree.prefix_match("foobar") + matched_text, matched_tenants = tree.prefix_match("foobar") assert matched_text == "" - assert len(tenants) == 2 - assert "tenant_1" in tenants - assert "tenant_2" in tenants + assert matched_tenants == ["tenant_1", "tenant_2"] # 3. Test exact match tree._reset() tree.insert("hello", "tenant_1", 1) - matched_text, tenants = tree.prefix_match("hello") + matched_text, matched_tenants = tree.prefix_match("hello") assert matched_text == "hello" - assert tenants == ["tenant_1"] + assert matched_tenants == ["tenant_1"] # 4. Test partial match tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) - text, tenants = tree.prefix_match("application") - assert text == "appl" - assert tenants == ["tenant_1"] + matched_text, matched_tenants = tree.prefix_match("application") + assert matched_text == "appl" + assert matched_tenants == ["tenant_1"] # 5. Test match by tenant tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) - text, tenants = tree.prefix_match("application", ["tenant_2"]) - assert text == "ap" - assert tenants == ["tenant_2"] + matched_text, matched_tenants = tree.prefix_match("application", ["tenant_2"]) + assert matched_text == "ap" + assert matched_tenants == ["tenant_2"] # 6. Test match by non-existent tenant tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) - text, tenants = tree.prefix_match("application", ["tenant_3"]) - assert text == "" - assert tenants is None + matched_text, matched_tenants = tree.prefix_match("application", ["tenant_3"]) + assert matched_text == "" + assert matched_tenants is None # 7. Test shared prefix matching with branches tree._reset() tree.insert("helloworld", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) - text_a, tenants_a = tree.prefix_match("helloworld") - text_b, tenants_b = tree.prefix_match("hellothereworld") - assert text_a == "helloworld" - assert tenants_a == ["tenant_1"] - assert text_b == "hellothere" - assert tenants_b == ["tenant_2"] + + matched_text, matched_tenants = tree.prefix_match("helloworld") + assert matched_text == "helloworld" + assert matched_tenants == ["tenant_1"] + + matched_text, matched_tenants = tree.prefix_match("hellothereworld") + assert matched_text == "hellothere" + assert matched_tenants == ["tenant_2"] def test__remove_tenant_single_node(tree: PrefixTree) -> None: """Test removing a single node for a tenant.""" # 1. Test removing a single node - # TEST FAILS: Ray creates new node instances when making remote calls? - # The node from insert.remote() is not identity-equal to the one in tenant_nodes tree._reset() tree.insert("hello", "tenant_1", 1) @@ -296,15 +306,15 @@ def test__remove_tenant_single_node(tree: PrefixTree) -> None: removed: int = tree._remove_tenant_single_node("tenant_1", h_node) assert removed == 5 - assert tree.tenant_char_count["tenant_1"] == 0 - assert len(tree.tenant_nodes["tenant_1"]) == 1 - assert tree.root in tree.tenant_nodes["tenant_1"] + assert tree.tenant_to_char_count["tenant_1"] == 0 + assert len(tree.tenant_to_nodes["tenant_1"]) == 1 + assert tree.root in tree.tenant_to_nodes["tenant_1"] # 2. Test removing node for non-existent tenant is idempotent tree._reset() tree.insert("hello", "tenant_1", 1) root: Node = tree.root - h_node: Optional[Node] = root.children.get("h") + h_node: Optional[Node] = root.edge_label_to_child.get("h") # Should not raise error, just return 0 removed = tree._remove_tenant_single_node("nonexistent_tenant", h_node) @@ -316,7 +326,7 @@ def test__remove_tenant_single_node(tree: PrefixTree) -> None: tree.insert("world", "tenant_2", 2) root = tree.root - h_node = root.children.get("h") + h_node = root.edge_label_to_child.get("h") # Should not raise error, just return 0 removed = tree._remove_tenant_single_node("tenant_2", h_node) @@ -331,9 +341,8 @@ def test_remove_tenant(tree: PrefixTree) -> None: removed: int = tree.remove_tenant("tenant_1") assert removed == 5 - assert "tenant_1" not in tree.tenants - assert "tenant_1" not in tree.tenant_char_count - assert "tenant_1" not in tree.tenant_nodes + assert "tenant_1" not in tree.tenant_to_nodes + assert "tenant_1" not in tree.tenant_to_char_count # 2. Test removing tenant with multiple nodes tree._reset() @@ -356,12 +365,12 @@ def test_remove_tenant(tree: PrefixTree) -> None: # Remove tenant_1, verify tenant_2 still works tree.remove_tenant("tenant_1") - assert "tenant_1" not in tree.tenants - assert "tenant_2" in tree.tenants + assert "tenant_1" not in tree.tenant_to_nodes + assert "tenant_2" in tree.tenant_to_nodes - matched_text, tenants = tree.prefix_match("hello") + matched_text, matched_tenants = tree.prefix_match("hello") assert matched_text == "hello" - assert tenants == ["tenant_2"] + assert matched_tenants == ["tenant_2"] # 5. Test removing the last tenant from a node removes the node tree._reset() @@ -373,9 +382,9 @@ def test_remove_tenant(tree: PrefixTree) -> None: root: Node = tree.root # 'h' node should only have one child now ('t' from hellothere) - assert "h" in root.children - assert "t" in root.children["h"].children - assert len(root.children["h"].children) == 1 + assert "h" in root.edge_label_to_child + assert "t" in root.edge_label_to_child["h"].edge_label_to_child + assert len(root.edge_label_to_child["h"].edge_label_to_child) == 1 def test_evict_tenant_by_lru(tree: PrefixTree) -> None: @@ -388,20 +397,20 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree.insert("ccc", "tenant_1", 3) # Before eviction - char_count_before: int = tree.tenant_char_count["tenant_1"] - assert len(tree.tenant_nodes["tenant_1"]) == 4 - assert tree.tenant_char_count["tenant_1"] == 6 + char_count_before: int = tree.tenant_to_char_count["tenant_1"] + assert len(tree.tenant_to_nodes["tenant_1"]) == 4 + assert tree.tenant_to_char_count["tenant_1"] == 6 # During eviction min_remove_size: int = 1 evicted_count: int = tree.evict_tenant_by_lru("tenant_1", min_remove_size) # After eviction - char_count_after: int = tree.tenant_char_count["tenant_1"] + char_count_after: int = tree.tenant_to_char_count["tenant_1"] assert evicted_count == min_remove_size assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_nodes["tenant_1"]) == 3 - assert tree.tenant_char_count["tenant_1"] == 5 + assert len(tree.tenant_to_nodes["tenant_1"]) == 3 + assert tree.tenant_to_char_count["tenant_1"] == 5 # 2. Remove more than min_remove_size characters tree._reset() @@ -410,20 +419,20 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree.insert("ccc", "tenant_1", 3) # Before eviction - char_count_before = tree.tenant_char_count["tenant_1"] - assert len(tree.tenant_nodes["tenant_1"]) == 4 - assert tree.tenant_char_count["tenant_1"] == 6 + char_count_before = tree.tenant_to_char_count["tenant_1"] + assert len(tree.tenant_to_nodes["tenant_1"]) == 4 + assert tree.tenant_to_char_count["tenant_1"] == 6 # During eviction min_remove_size = 2 evicted_count = tree.evict_tenant_by_lru("tenant_1", min_remove_size) # After eviction - char_count_after = tree.tenant_char_count["tenant_1"] + char_count_after = tree.tenant_to_char_count["tenant_1"] assert evicted_count != min_remove_size and evicted_count == 3 assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_nodes["tenant_1"]) == 2 - assert tree.tenant_char_count["tenant_1"] == 3 + assert len(tree.tenant_to_nodes["tenant_1"]) == 2 + assert tree.tenant_to_char_count["tenant_1"] == 3 # 3. Test eviction of non-existent tenant is idempotent tree._reset() @@ -442,15 +451,15 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree._reset() tree.insert("xyz", "tenant_1", 1) - total_size: int = tree.tenant_char_count["tenant_1"] + total_size: int = tree.tenant_to_char_count["tenant_1"] evicted_count = tree.evict_tenant_by_lru("tenant_1", total_size) assert evicted_count == total_size - # "tenant_1" should still be in tenants - assert "tenant_1" in tree.tenants + # "tenant_1" should still be in tenant_to_nodes + assert "tenant_1" in tree.tenant_to_nodes - # 6. Test tree structure and LRU heap ordering + # 6. Test tree structure and LRU eviction tree._reset() # Insert strings in specified order @@ -464,71 +473,81 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Test tree structure - validate each node # Root node assert root.text == "" - assert root.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "h" in root.children + assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "h" in root.edge_label_to_child # Hello node - hello_node: Node = root.children["h"] + hello_node: Node = root.edge_label_to_child["h"] assert hello_node.text == "hello" - assert hello_node.tenant_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "w" in hello_node.children - assert "t" in hello_node.children + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "w" in hello_node.edge_label_to_child + assert "t" in hello_node.edge_label_to_child # World node - world_node: Node = hello_node.children["w"] + world_node: Node = hello_node.edge_label_to_child["w"] assert world_node.text == "world" - assert world_node.tenant_last_access_time == {"tenant_1": 1} - assert len(world_node.children) == 0 + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert len(world_node.edge_label_to_child) == 0 # Th node - th_node: Node = hello_node.children["t"] + th_node: Node = hello_node.edge_label_to_child["t"] assert th_node.text == "th" - assert th_node.tenant_last_access_time == {"tenant_2": 3} - assert "e" in th_node.children - assert "o" in th_node.children + assert th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert "e" in th_node.edge_label_to_child + assert "o" in th_node.edge_label_to_child # Ere node - ere_node: Node = th_node.children["e"] + ere_node: Node = th_node.edge_label_to_child["e"] assert ere_node.text == "ere" - assert ere_node.tenant_last_access_time == {"tenant_2": 2} - assert len(ere_node.children) == 0 + assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert len(ere_node.edge_label_to_child) == 0 # Omas node - omas_node: Node = th_node.children["o"] + omas_node: Node = th_node.edge_label_to_child["o"] assert omas_node.text == "omas" - assert omas_node.tenant_last_access_time == {"tenant_2": 3} - assert len(omas_node.children) == 0 + assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert len(omas_node.edge_label_to_child) == 0 # Test PrefixTree instance variables - assert tree.tenants == {"tenant_1", "tenant_2"} - - # Test tenant_char_count - assert ( - tree.tenant_char_count["tenant_1"] == 10 - ) # root(0) + hello(5) + world(5) = 10 - assert ( - tree.tenant_char_count["tenant_2"] == 14 - ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - - # Test tenant_nodes (check by text) + assert set(tree.tenant_to_nodes.keys()) == {"tenant_1", "tenant_2"} + + # Test tenant_to_nodes (check by text) tenant1_nodes_texts: Set[str] = { - node.text for node in tree.tenant_nodes["tenant_1"] + node.text for node in tree.tenant_to_nodes["tenant_1"] } assert tenant1_nodes_texts == {"", "hello", "world"} tenant2_nodes_texts: Set[str] = { - node.text for node in tree.tenant_nodes["tenant_2"] + node.text for node in tree.tenant_to_nodes["tenant_2"] } assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - # Test tenant_nodes_sorted - validate heap ordering - tenant1_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_1"] - tenant2_heap: List[TenantHeapNode] = tree.tenant_nodes_sorted["tenant_2"] + # Test tenant_to_char_count + ## Before evictions + assert ( + tree.tenant_to_char_count["tenant_1"] == 10 + ) # root(0) + hello(5) + world(5) = 10 + assert ( + tree.tenant_to_char_count["tenant_2"] == 14 + ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tenant1_heap).node.tenant_last_access_time["tenant_1"] == 1 - assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 2 - assert heapq.heappop(tenant2_heap).node.tenant_last_access_time["tenant_2"] == 3 + ## After evicting tenant_1 with min_remove_size=1 + # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp + evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) + assert evicted_count == 10 # All 10 chars removed, not just 1 + assert tree.tenant_to_char_count["tenant_1"] == 0 + + ## After evicting tenant_2 with min_remove_size=1 + # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) + evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) + assert evicted_count == 3 # All 3 chars from "ere" removed + assert tree.tenant_to_char_count["tenant_2"] == 11 # 14 - 3 = 11 + + ## After evicting tenant_2 again with min_remove_size=1 + # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 + evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) + assert evicted_count == 11 # All 11 remaining chars removed + assert tree.tenant_to_char_count["tenant_2"] == 0 def test_get_smallest_tenant(tree: PrefixTree) -> None: From a0db565531d72bae0aeb9636a038f9fe726c4cb3 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Wed, 7 May 2025 13:36:26 -0700 Subject: [PATCH 11/15] Update tests Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 49 +++--- .../serve/cpu/deployments/test_prefix_tree.py | 165 ++++++------------ 2 files changed, 77 insertions(+), 137 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 544686164410f..fbf53df057a18 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -41,37 +41,34 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: {} ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) -class TenantHeapNode: +class TimestampedNode: """ - Wrapper class for storing nodes in a min-heap, ordered by tenant access time. - Used for efficient LRU eviction of tenant nodes. + Wrapper class for storing nodes in a min-heap, ordered by timestamp. + Used for efficient LRU eviction of nodes. """ - def __init__(self, node: Node, tenant_ordering_key: str) -> None: + def __init__(self, node: Node, time_sec: float) -> None: """ - Initialize a heap node for efficient LRU tenant management. + Initialize a heap node for efficient LRU eviction of nodes. Args: node: The prefix tree node this heap node refers to - tenant_ordering_key: The tenant this heap uses to order nodes + time_sec: The timestamp this heap uses to order nodes """ self.node = node - self.tenant_ordering_key = tenant_ordering_key + self.time_sec = time_sec - def __lt__(self, other: TenantHeapNode) -> bool: + def __lt__(self, other: TimestampedNode) -> bool: """ - Compare heap nodes based on tenant's last access time. + Compare heap nodes based on timestamp. Args: - other: Another TenantHeapNode to compare with + other: Another TimestampedNode to compare with Returns: - True if this node's tenant access time is earlier than the other's + True if this node's timestamp is earlier than the other's """ - return ( - self.node.tenant_to_last_access_time[self.tenant_ordering_key] - < other.node.tenant_to_last_access_time[other.tenant_ordering_key] - ) + return self.time_sec < other.time_sec class PrefixTree: @@ -177,11 +174,6 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: tenant: Tenant to remove node: Node to remove tenant from - Does: - Decrements self.tenant_to_char_count[tenant] by the length of the node's text. - Removes the tenant from node.tenant_to_last_access_time. - Removes the node from self.tenant_to_nodes[tenant]. - Returns: Number of characters removed (0 if preconditions not met) """ @@ -328,6 +320,9 @@ def prefix_match( Returns: Tuple of (matched_text, matched_tenants) + - If the list of available tenants doesn't match any tenants in the tree: returns ("", None) + - When no prefix match is found (does not traverse further than the root node): returns ("", list of available tenants) + - When a prefix match is found: returns (matched_prefix, list of tenants that own the matched node) """ if available_tenants: # Filter available_tenants to only include those in the tree @@ -390,7 +385,7 @@ def remove_tenant(self, tenant: str) -> int: tenant: Tenant to remove Returns: - Number of characters removed + Number of characters removed (0 if tenant doesn't exist) """ with self.lock: if tenant not in self.tenant_to_nodes: @@ -415,7 +410,7 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: min_remove_size: Minimum number of characters to remove Returns: - Actual number of characters removed + Actual number of characters removed (0 if tenant doesn't exist) Note: - All nodes with the same oldest access time are removed together to maintain tree integrity, even if only removing a subset of them satisfies the min_remove_size. @@ -442,11 +437,11 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: total_chars_removed: int = 0 # Create a min-heap of nodes ordered by access time - # Each entry is a tuple of (access_time, node) so heapq sorts by access_time first + # Each entry is a TimestampedNode(node, access_time) object, which has a __lt__ method that is used by heapq. nodes_by_access_time = [] for node in self.tenant_to_nodes[tenant]: access_time = node.tenant_to_last_access_time[tenant] - nodes_by_access_time.append((access_time, node)) + nodes_by_access_time.append(TimestampedNode(node, access_time)) heapq.heapify(nodes_by_access_time) # Remove nodes until we've freed enough characters @@ -455,15 +450,15 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: and nodes_by_access_time ): # Get the oldest (minimum) access time from the top of the heap - oldest_access_time = nodes_by_access_time[0][0] + oldest_access_time = nodes_by_access_time[0].time_sec # Remove ALL nodes with this same access time to maintain tree consistency # (partial removals could break prefix relationships) while ( nodes_by_access_time - and nodes_by_access_time[0][0] == oldest_access_time + and nodes_by_access_time[0].time_sec == oldest_access_time ): - _, node_to_remove = heapq.heappop(nodes_by_access_time) + node_to_remove = heapq.heappop(nodes_by_access_time).node total_chars_removed += self._remove_tenant_single_node( tenant, node_to_remove ) diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index 235f4c3caadb5..1598fd9173fa5 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -1,7 +1,6 @@ import pytest import ray from ray import serve -import heapq from typing import Set, List, Dict, Optional, Generator from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( @@ -111,33 +110,31 @@ async def test_tree_deployment(tree_deployment) -> None: assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} # Test tenant_to_char_count - ## Before evictions - assert ( - tree_rep["tenant_to_char_count"]["tenant_1"] == 10 - ) # root(0) + hello(5) + world(5) = 10 - assert ( - tree_rep["tenant_to_char_count"]["tenant_2"] == 14 - ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - - ## After evicting tenant_1 with min_remove_size=1 + # Before evictions + assert tree_rep["tenant_to_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 + assert tree_rep["tenant_to_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + + # After evicting tenant_1 with min_remove_size=1 # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_1", 1) assert evicted_count == 10 # All 10 chars removed, not just 1 - tree_rep: Dict = await tree_deployment._to_dict.remote() + tree_rep = await tree_deployment._to_dict.remote() assert tree_rep["tenant_to_char_count"]["tenant_1"] == 0 - ## After evicting tenant_2 with min_remove_size=1 + # After evicting tenant_2 with min_remove_size=1 # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) assert evicted_count == 3 # All 3 chars from "ere" removed - tree_rep: Dict = await tree_deployment._to_dict.remote() + + tree_rep = await tree_deployment._to_dict.remote() assert tree_rep["tenant_to_char_count"]["tenant_2"] == 11 # 14 - 3 = 11 - ## After evicting tenant_2 again with min_remove_size=1 + # After evicting tenant_2 again with min_remove_size=1 # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) assert evicted_count == 11 # All 11 remaining chars removed - tree_rep: Dict = await tree_deployment._to_dict.remote() + + tree_rep = await tree_deployment._to_dict.remote() assert tree_rep["tenant_to_char_count"]["tenant_2"] == 0 # PrefixTree tests @@ -166,9 +163,7 @@ def test_insert(tree: PrefixTree) -> None: # No need to call add_tenant first - insert will do it automatically tree.insert("hello", "tenant_1", 1) matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" - assert matched_tenants == ["tenant_1"] - + assert matched_text == "hello" and matched_tenants == ["tenant_1"] assert tree.tenant_to_char_count["tenant_1"] == 5 assert len(tree.tenant_to_nodes["tenant_1"]) == 2 @@ -177,9 +172,7 @@ def test_insert(tree: PrefixTree) -> None: tree.insert("foo", "tenant_1", 1) tree.insert("foo", "tenant_1", 1) # duplicate tree.insert("bar", "tenant_2", 2) - - assert tree.tenant_to_char_count["tenant_1"] == 3 - assert tree.tenant_to_char_count["tenant_2"] == 3 + assert tree.tenant_to_char_count["tenant_1"] == 3 and tree.tenant_to_char_count["tenant_2"] == 3 # 3. Test node splitting on partial match tree._reset() @@ -188,8 +181,7 @@ def test_insert(tree: PrefixTree) -> None: root: Node = tree.root h_node: Optional[Node] = root.edge_label_to_child.get("h") - assert h_node is not None - assert h_node.text == "hello" + assert h_node is not None and h_node.text == "hello" assert h_node.edge_label_to_child.get("w").text == "world" assert h_node.edge_label_to_child.get("t").text == "there" @@ -213,22 +205,17 @@ def test_insert(tree: PrefixTree) -> None: nodes_to_check.extend(node.edge_label_to_child.values()) # There should be exactly one empty text node (the root) - assert len(empty_text_nodes) == 1 - assert root in empty_text_nodes + assert len(empty_text_nodes) == 1 and root in empty_text_nodes # Verify tree structure h_node = root.edge_label_to_child.get("h") - assert h_node is not None - assert h_node.text == "hello" - assert "tenant_1" in h_node.tenant_to_last_access_time - assert "tenant_2" in h_node.tenant_to_last_access_time + assert h_node is not None and h_node.text == "hello" + assert "tenant_1" in h_node.tenant_to_last_access_time and "tenant_2" in h_node.tenant_to_last_access_time # Verify "world" node belongs only to tenant 2 world_node: Optional[Node] = h_node.edge_label_to_child.get("w") - assert world_node is not None - assert world_node.text == "world" - assert "tenant_2" in world_node.tenant_to_last_access_time - assert "tenant_1" not in world_node.tenant_to_last_access_time + assert world_node is not None and world_node.text == "world" + assert "tenant_2" in world_node.tenant_to_last_access_time and "tenant_1" not in world_node.tenant_to_last_access_time # Verify the only child of h_node is "w" assert len(h_node.edge_label_to_child) == 1 @@ -239,47 +226,41 @@ def test_prefix_match(tree: PrefixTree) -> None: # 1. Test no match tree._reset() matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "" - assert matched_tenants is None + assert matched_text == "" and matched_tenants is None # 2. Test match with non-existing prefix returns empty string and all tenants tree._reset() tree.insert("hello", "tenant_1", 1) tree.insert("hellothere", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("foobar") - assert matched_text == "" - assert matched_tenants == ["tenant_1", "tenant_2"] + assert matched_text == "" and matched_tenants == ["tenant_1", "tenant_2"] # 3. Test exact match tree._reset() tree.insert("hello", "tenant_1", 1) matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" - assert matched_tenants == ["tenant_1"] + assert matched_text == "hello" and matched_tenants == ["tenant_1"] # 4. Test partial match tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("application") - assert matched_text == "appl" - assert matched_tenants == ["tenant_1"] + assert matched_text == "appl" and matched_tenants == ["tenant_1"] # 5. Test match by tenant tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("application", ["tenant_2"]) - assert matched_text == "ap" - assert matched_tenants == ["tenant_2"] + assert matched_text == "ap" and matched_tenants == ["tenant_2"] # 6. Test match by non-existent tenant tree._reset() tree.insert("apple", "tenant_1", 1) tree.insert("apricot", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("application", ["tenant_3"]) - assert matched_text == "" - assert matched_tenants is None + assert matched_text == "" and matched_tenants is None # 7. Test shared prefix matching with branches tree._reset() @@ -287,28 +268,23 @@ def test_prefix_match(tree: PrefixTree) -> None: tree.insert("hellothere", "tenant_2", 2) matched_text, matched_tenants = tree.prefix_match("helloworld") - assert matched_text == "helloworld" - assert matched_tenants == ["tenant_1"] + assert matched_text == "helloworld" and matched_tenants == ["tenant_1"] matched_text, matched_tenants = tree.prefix_match("hellothereworld") - assert matched_text == "hellothere" - assert matched_tenants == ["tenant_2"] + assert matched_text == "hellothere" and matched_tenants == ["tenant_2"] def test__remove_tenant_single_node(tree: PrefixTree) -> None: """Test removing a single node for a tenant.""" # 1. Test removing a single node - tree._reset() tree.insert("hello", "tenant_1", 1) h_node: Node = tree.insert("hello", "tenant_1", 1) removed: int = tree._remove_tenant_single_node("tenant_1", h_node) assert removed == 5 - assert tree.tenant_to_char_count["tenant_1"] == 0 - assert len(tree.tenant_to_nodes["tenant_1"]) == 1 - assert tree.root in tree.tenant_to_nodes["tenant_1"] + assert len(tree.tenant_to_nodes["tenant_1"]) == 1 and tree.root in tree.tenant_to_nodes["tenant_1"] # 2. Test removing node for non-existent tenant is idempotent tree._reset() @@ -340,9 +316,7 @@ def test_remove_tenant(tree: PrefixTree) -> None: tree.insert("hello", "tenant_1", 1) removed: int = tree.remove_tenant("tenant_1") assert removed == 5 - - assert "tenant_1" not in tree.tenant_to_nodes - assert "tenant_1" not in tree.tenant_to_char_count + assert "tenant_1" not in tree.tenant_to_nodes and "tenant_1" not in tree.tenant_to_char_count # 2. Test removing tenant with multiple nodes tree._reset() @@ -364,13 +338,10 @@ def test_remove_tenant(tree: PrefixTree) -> None: # Remove tenant_1, verify tenant_2 still works tree.remove_tenant("tenant_1") - - assert "tenant_1" not in tree.tenant_to_nodes - assert "tenant_2" in tree.tenant_to_nodes + assert "tenant_1" not in tree.tenant_to_nodes and "tenant_2" in tree.tenant_to_nodes matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" - assert matched_tenants == ["tenant_2"] + assert matched_text == "hello" and matched_tenants == ["tenant_2"] # 5. Test removing the last tenant from a node removes the node tree._reset() @@ -398,8 +369,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Before eviction char_count_before: int = tree.tenant_to_char_count["tenant_1"] - assert len(tree.tenant_to_nodes["tenant_1"]) == 4 - assert tree.tenant_to_char_count["tenant_1"] == 6 + assert len(tree.tenant_to_nodes["tenant_1"]) == 4 and tree.tenant_to_char_count["tenant_1"] == 6 # During eviction min_remove_size: int = 1 @@ -409,8 +379,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: char_count_after: int = tree.tenant_to_char_count["tenant_1"] assert evicted_count == min_remove_size assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_to_nodes["tenant_1"]) == 3 - assert tree.tenant_to_char_count["tenant_1"] == 5 + assert len(tree.tenant_to_nodes["tenant_1"]) == 3 and tree.tenant_to_char_count["tenant_1"] == 5 # 2. Remove more than min_remove_size characters tree._reset() @@ -420,8 +389,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Before eviction char_count_before = tree.tenant_to_char_count["tenant_1"] - assert len(tree.tenant_to_nodes["tenant_1"]) == 4 - assert tree.tenant_to_char_count["tenant_1"] == 6 + assert len(tree.tenant_to_nodes["tenant_1"]) == 4 and tree.tenant_to_char_count["tenant_1"] == 6 # During eviction min_remove_size = 2 @@ -431,8 +399,7 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: char_count_after = tree.tenant_to_char_count["tenant_1"] assert evicted_count != min_remove_size and evicted_count == 3 assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_to_nodes["tenant_1"]) == 2 - assert tree.tenant_to_char_count["tenant_1"] == 3 + assert len(tree.tenant_to_nodes["tenant_1"]) == 2 and tree.tenant_to_char_count["tenant_1"] == 3 # 3. Test eviction of non-existent tenant is idempotent tree._reset() @@ -452,10 +419,8 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: tree.insert("xyz", "tenant_1", 1) total_size: int = tree.tenant_to_char_count["tenant_1"] - evicted_count = tree.evict_tenant_by_lru("tenant_1", total_size) assert evicted_count == total_size - # "tenant_1" should still be in tenant_to_nodes assert "tenant_1" in tree.tenant_to_nodes @@ -472,82 +437,62 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Test tree structure - validate each node # Root node - assert root.text == "" - assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert root.text == "" and root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} assert "h" in root.edge_label_to_child # Hello node hello_node: Node = root.edge_label_to_child["h"] - assert hello_node.text == "hello" - assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "w" in hello_node.edge_label_to_child - assert "t" in hello_node.edge_label_to_child + assert hello_node.text == "hello" and hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert "w" in hello_node.edge_label_to_child and "t" in hello_node.edge_label_to_child # World node world_node: Node = hello_node.edge_label_to_child["w"] - assert world_node.text == "world" - assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert world_node.text == "world" and world_node.tenant_to_last_access_time == {"tenant_1": 1} assert len(world_node.edge_label_to_child) == 0 # Th node th_node: Node = hello_node.edge_label_to_child["t"] - assert th_node.text == "th" - assert th_node.tenant_to_last_access_time == {"tenant_2": 3} - assert "e" in th_node.edge_label_to_child - assert "o" in th_node.edge_label_to_child + assert th_node.text == "th" and th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert "e" in th_node.edge_label_to_child and "o" in th_node.edge_label_to_child # Ere node ere_node: Node = th_node.edge_label_to_child["e"] - assert ere_node.text == "ere" - assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert ere_node.text == "ere" and ere_node.tenant_to_last_access_time == {"tenant_2": 2} assert len(ere_node.edge_label_to_child) == 0 # Omas node omas_node: Node = th_node.edge_label_to_child["o"] - assert omas_node.text == "omas" - assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert omas_node.text == "omas" and omas_node.tenant_to_last_access_time == {"tenant_2": 3} assert len(omas_node.edge_label_to_child) == 0 # Test PrefixTree instance variables assert set(tree.tenant_to_nodes.keys()) == {"tenant_1", "tenant_2"} # Test tenant_to_nodes (check by text) - tenant1_nodes_texts: Set[str] = { - node.text for node in tree.tenant_to_nodes["tenant_1"] - } + tenant1_nodes_texts: Set[str] = {node.text for node in tree.tenant_to_nodes["tenant_1"]} assert tenant1_nodes_texts == {"", "hello", "world"} - tenant2_nodes_texts: Set[str] = { - node.text for node in tree.tenant_to_nodes["tenant_2"] - } + tenant2_nodes_texts: Set[str] = {node.text for node in tree.tenant_to_nodes["tenant_2"]} assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} # Test tenant_to_char_count - ## Before evictions - assert ( - tree.tenant_to_char_count["tenant_1"] == 10 - ) # root(0) + hello(5) + world(5) = 10 - assert ( - tree.tenant_to_char_count["tenant_2"] == 14 - ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - - ## After evicting tenant_1 with min_remove_size=1 + # Before evictions + assert tree.tenant_to_char_count["tenant_1"] == 10 and tree.tenant_to_char_count["tenant_2"] == 14 + + # After evicting tenant_1 with min_remove_size=1 # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) - assert evicted_count == 10 # All 10 chars removed, not just 1 - assert tree.tenant_to_char_count["tenant_1"] == 0 + assert evicted_count == 10 and tree.tenant_to_char_count["tenant_1"] == 0 - ## After evicting tenant_2 with min_remove_size=1 + # After evicting tenant_2 with min_remove_size=1 # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) - assert evicted_count == 3 # All 3 chars from "ere" removed - assert tree.tenant_to_char_count["tenant_2"] == 11 # 14 - 3 = 11 + assert evicted_count == 3 and tree.tenant_to_char_count["tenant_2"] == 11 # 14 - 3 = 11 - ## After evicting tenant_2 again with min_remove_size=1 + # After evicting tenant_2 again with min_remove_size=1 # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) - assert evicted_count == 11 # All 11 remaining chars removed - assert tree.tenant_to_char_count["tenant_2"] == 0 + assert evicted_count == 11 and tree.tenant_to_char_count["tenant_2"] == 0 def test_get_smallest_tenant(tree: PrefixTree) -> None: From 9fa20bb1cbb196535b9ed0dd373b28cc69297ee4 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Wed, 7 May 2025 14:21:23 -0700 Subject: [PATCH 12/15] Add PrefixTreeActor Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 51 ++--- .../serve/cpu/deployments/test_prefix_tree.py | 176 +++++++++++------- 2 files changed, 138 insertions(+), 89 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index fbf53df057a18..9a026dfebe25e 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -6,7 +6,7 @@ from threading import RLock from typing import Dict, List, Optional, Set, Tuple, Any -from ray import serve +import ray logger = logging.getLogger(__name__) @@ -34,13 +34,16 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: """ self.text: str = text self.parent: Optional[Node] = parent # The parent node of this node - self.edge_label_to_child: Dict[str, Node] = {} # Maps first character to child node + self.edge_label_to_child: Dict[ + str, Node + ] = {} # Maps first character to child node self.tenant_to_last_access_time: Dict[ str, int ] = ( {} ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) + class TimestampedNode: """ Wrapper class for storing nodes in a min-heap, ordered by timestamp. @@ -111,10 +114,14 @@ def __init__(self) -> None: self.root: Node = Node() self.tenant_to_char_count: Dict[ str, int - ] = {} # Tracks total character count per tenant. Used by the client to determine which tenant to evict, and by how much. + ] = ( + {} + ) # Tracks total character count per tenant. Used by the client to determine which tenant to evict, and by how much. self.tenant_to_nodes: Dict[ str, Set[Node] - ] = {} # Maps tenant to set of nodes. Used for O(1) testing if a node belongs to a tenant. The keys are the active tenants in the tree. + ] = ( + {} + ) # Maps tenant to set of nodes. Used for O(1) testing if a node belongs to a tenant. The keys are the active tenants in the tree. @staticmethod def _shared_prefix_count(a: str, b: str) -> int: @@ -166,7 +173,7 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: - tenant exists in self.tenant_to_nodes - tenant exists in node.tenant_to_last_access_time - node exists in self.tenant_to_nodes[tenant] - + These preconditions are guaranteed to be satisfied if the user is using the public methods of this class. They may be violated if the user manipulates the internal state of the tree directly. @@ -182,10 +189,14 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") return 0 if tenant not in node.tenant_to_last_access_time: - logger.warning(f"Tenant '{tenant}' does not have node '{node.text}'. No action taken.") + logger.warning( + f"Tenant '{tenant}' does not have node '{node.text}'. No action taken." + ) return 0 if node not in self.tenant_to_nodes[tenant]: - logger.warning(f"Node '{node.text}' does not belong to tenant '{tenant}'. No action taken.") + logger.warning( + f"Node '{node.text}' does not belong to tenant '{tenant}'. No action taken." + ) return 0 removed_chars_len: int = len(node.text) @@ -251,7 +262,7 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: curr_text: str = text[i:] if first_char not in curr_node.edge_label_to_child: - # No match, create new node. Don't update new node as "visited" by tenant yet; it will be done in the code below. + # No match, create new node. Don't update new node as "visited" by tenant yet; it will be done at the beginning of the next iteration. # e.g. curr_node.edge_label_to_child = {}, curr_text = "hello" -> curr_node.edge_label_to_child = {"h": Node("hello")} new_node: Node = Node(text=curr_text, parent=curr_node) curr_node.edge_label_to_child[first_char] = new_node @@ -269,13 +280,10 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: ### curr_node.edge_label_to_child = {"h": Node("helloworld")}, curr_text = "hellothere" -> shared_count = 5 ### matched_node = Node("helloworld") - ## During update: - ### Increment tenant_to_char_count[tenant] by shared_count if matched_node has not seen this tenant before - ## After update: ### curr_node.edge_label_to_child = {"h": Node("hello", edge_label_to_child = {"w": Node("world")})} ### parent_node = Node("hello"), matched_node = Node("world") - ### Update tenant_to_last_access_time for parent_node, NOT matched_node + ### Copy matched_node.tenant_to_last_access_time to parent_node.tenant_to_last_access_time ### (new) curr_text = "there", (new) curr_node = parent_node ### Continue adding "there" to tree in next iteration @@ -411,7 +419,7 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: Returns: Actual number of characters removed (0 if tenant doesn't exist) - + Note: - All nodes with the same oldest access time are removed together to maintain tree integrity, even if only removing a subset of them satisfies the min_remove_size. - This behavior is expected in the case when an input was split into multiple nodes by a different tenant (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 2)). @@ -443,12 +451,9 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: access_time = node.tenant_to_last_access_time[tenant] nodes_by_access_time.append(TimestampedNode(node, access_time)) heapq.heapify(nodes_by_access_time) - + # Remove nodes until we've freed enough characters - while ( - total_chars_removed < min_remove_size - and nodes_by_access_time - ): + while total_chars_removed < min_remove_size and nodes_by_access_time: # Get the oldest (minimum) access time from the top of the heap oldest_access_time = nodes_by_access_time[0].time_sec @@ -476,11 +481,15 @@ def get_smallest_tenant(self) -> Optional[str]: if not self.tenant_to_char_count: return None - return min(self.tenant_to_char_count, key=self.tenant_to_char_count.get, default=None) + return min( + self.tenant_to_char_count, + key=self.tenant_to_char_count.get, + default=None, + ) -@serve.deployment(name="TreeDeployment") -class PrefixTreeDeployment(PrefixTree): +@ray.remote +class PrefixTreeActor(PrefixTree): def _to_dict(self) -> Dict[str, Any]: """ Convert tree to dictionary for serialization. diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index 1598fd9173fa5..0e39859fa85c8 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -1,11 +1,10 @@ import pytest import ray -from ray import serve -from typing import Set, List, Dict, Optional, Generator +from typing import Set, List, Dict, Optional from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( PrefixTree, - PrefixTreeDeployment, + PrefixTreeActor, Node, ) @@ -17,43 +16,27 @@ def tree() -> PrefixTree: return PrefixTree() -@pytest.fixture(scope="module", autouse=True) -def serve_instance() -> Generator[None, None, None]: - # Start Ray and Serve once per test module - ray.init(ignore_reinit_error=True) - serve.start(detached=True) - yield - serve.shutdown() - ray.shutdown() - - @pytest.fixture(scope="module") -def tree_deployment(): - """Create a fresh PrefixTreeDeployment instance for each test.""" - tree = serve.run(PrefixTreeDeployment.bind()) - return tree +def tree_actor(): + """Create a fresh PrefixTreeActor instance for each test.""" + tree_actor = PrefixTreeActor.remote() + return tree_actor -# PrefixTreeDeployment tests +# PrefixTreeActor tests @pytest.mark.asyncio -async def test_tree_deployment(tree_deployment) -> None: - """Test the PrefixTreeDeployment.""" - # 6. Test tree structure and LRU heap ordering - await tree_deployment._reset.remote() +async def test_tree_actor(tree_actor) -> None: + """Test the PrefixTreeActor.""" + # 1. Test tree structure and LRU heap ordering + tree_actor._reset.remote() # Insert strings in specified order - await tree_deployment.insert.remote( - "helloworld", "tenant_1", 1 - ) # time 1 for tenant_1 - await tree_deployment.insert.remote( - "hellothere", "tenant_2", 2 - ) # time 2 for tenant_2 - await tree_deployment.insert.remote( - "hellothomas", "tenant_2", 3 - ) # time 3 for tenant_2 + tree_actor.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 + tree_actor.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 + tree_actor.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 # Access tree directly - tree_rep: Dict = await tree_deployment._to_dict.remote() + tree_rep: Dict = ray.get(tree_actor._to_dict.remote()) root: Node = tree_rep["root"] # Test tree structure - validate each node @@ -111,32 +94,37 @@ async def test_tree_deployment(tree_deployment) -> None: # Test tenant_to_char_count # Before evictions - assert tree_rep["tenant_to_char_count"]["tenant_1"] == 10 # root(0) + hello(5) + world(5) = 10 - assert tree_rep["tenant_to_char_count"]["tenant_2"] == 14 # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 + assert ( + tree_rep["tenant_to_char_count"]["tenant_1"] == 10 + ) # root(0) + hello(5) + world(5) = 10 + assert ( + tree_rep["tenant_to_char_count"]["tenant_2"] == 14 + ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 # After evicting tenant_1 with min_remove_size=1 # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp - evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_1", 1) + evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_1", 1)) assert evicted_count == 10 # All 10 chars removed, not just 1 - tree_rep = await tree_deployment._to_dict.remote() + tree_rep = ray.get(tree_actor._to_dict.remote()) assert tree_rep["tenant_to_char_count"]["tenant_1"] == 0 - + # After evicting tenant_2 with min_remove_size=1 # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) - evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) + evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) assert evicted_count == 3 # All 3 chars from "ere" removed - - tree_rep = await tree_deployment._to_dict.remote() + + tree_rep = ray.get(tree_actor._to_dict.remote()) assert tree_rep["tenant_to_char_count"]["tenant_2"] == 11 # 14 - 3 = 11 - + # After evicting tenant_2 again with min_remove_size=1 # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 - evicted_count = await tree_deployment.evict_tenant_by_lru.remote("tenant_2", 1) + evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) assert evicted_count == 11 # All 11 remaining chars removed - - tree_rep = await tree_deployment._to_dict.remote() + + tree_rep = ray.get(tree_actor._to_dict.remote()) assert tree_rep["tenant_to_char_count"]["tenant_2"] == 0 + # PrefixTree tests def test__add_tenant(tree: PrefixTree) -> None: """Test adding tenants to the tree via the private _add_tenant method.""" @@ -172,7 +160,10 @@ def test_insert(tree: PrefixTree) -> None: tree.insert("foo", "tenant_1", 1) tree.insert("foo", "tenant_1", 1) # duplicate tree.insert("bar", "tenant_2", 2) - assert tree.tenant_to_char_count["tenant_1"] == 3 and tree.tenant_to_char_count["tenant_2"] == 3 + assert ( + tree.tenant_to_char_count["tenant_1"] == 3 + and tree.tenant_to_char_count["tenant_2"] == 3 + ) # 3. Test node splitting on partial match tree._reset() @@ -210,12 +201,18 @@ def test_insert(tree: PrefixTree) -> None: # Verify tree structure h_node = root.edge_label_to_child.get("h") assert h_node is not None and h_node.text == "hello" - assert "tenant_1" in h_node.tenant_to_last_access_time and "tenant_2" in h_node.tenant_to_last_access_time + assert ( + "tenant_1" in h_node.tenant_to_last_access_time + and "tenant_2" in h_node.tenant_to_last_access_time + ) # Verify "world" node belongs only to tenant 2 world_node: Optional[Node] = h_node.edge_label_to_child.get("w") assert world_node is not None and world_node.text == "world" - assert "tenant_2" in world_node.tenant_to_last_access_time and "tenant_1" not in world_node.tenant_to_last_access_time + assert ( + "tenant_2" in world_node.tenant_to_last_access_time + and "tenant_1" not in world_node.tenant_to_last_access_time + ) # Verify the only child of h_node is "w" assert len(h_node.edge_label_to_child) == 1 @@ -284,7 +281,10 @@ def test__remove_tenant_single_node(tree: PrefixTree) -> None: removed: int = tree._remove_tenant_single_node("tenant_1", h_node) assert removed == 5 assert tree.tenant_to_char_count["tenant_1"] == 0 - assert len(tree.tenant_to_nodes["tenant_1"]) == 1 and tree.root in tree.tenant_to_nodes["tenant_1"] + assert ( + len(tree.tenant_to_nodes["tenant_1"]) == 1 + and tree.root in tree.tenant_to_nodes["tenant_1"] + ) # 2. Test removing node for non-existent tenant is idempotent tree._reset() @@ -316,7 +316,10 @@ def test_remove_tenant(tree: PrefixTree) -> None: tree.insert("hello", "tenant_1", 1) removed: int = tree.remove_tenant("tenant_1") assert removed == 5 - assert "tenant_1" not in tree.tenant_to_nodes and "tenant_1" not in tree.tenant_to_char_count + assert ( + "tenant_1" not in tree.tenant_to_nodes + and "tenant_1" not in tree.tenant_to_char_count + ) # 2. Test removing tenant with multiple nodes tree._reset() @@ -369,7 +372,10 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Before eviction char_count_before: int = tree.tenant_to_char_count["tenant_1"] - assert len(tree.tenant_to_nodes["tenant_1"]) == 4 and tree.tenant_to_char_count["tenant_1"] == 6 + assert ( + len(tree.tenant_to_nodes["tenant_1"]) == 4 + and tree.tenant_to_char_count["tenant_1"] == 6 + ) # During eviction min_remove_size: int = 1 @@ -379,7 +385,10 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: char_count_after: int = tree.tenant_to_char_count["tenant_1"] assert evicted_count == min_remove_size assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_to_nodes["tenant_1"]) == 3 and tree.tenant_to_char_count["tenant_1"] == 5 + assert ( + len(tree.tenant_to_nodes["tenant_1"]) == 3 + and tree.tenant_to_char_count["tenant_1"] == 5 + ) # 2. Remove more than min_remove_size characters tree._reset() @@ -389,7 +398,10 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Before eviction char_count_before = tree.tenant_to_char_count["tenant_1"] - assert len(tree.tenant_to_nodes["tenant_1"]) == 4 and tree.tenant_to_char_count["tenant_1"] == 6 + assert ( + len(tree.tenant_to_nodes["tenant_1"]) == 4 + and tree.tenant_to_char_count["tenant_1"] == 6 + ) # During eviction min_remove_size = 2 @@ -399,7 +411,10 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: char_count_after = tree.tenant_to_char_count["tenant_1"] assert evicted_count != min_remove_size and evicted_count == 3 assert char_count_before - char_count_after == evicted_count - assert len(tree.tenant_to_nodes["tenant_1"]) == 2 and tree.tenant_to_char_count["tenant_1"] == 3 + assert ( + len(tree.tenant_to_nodes["tenant_1"]) == 2 + and tree.tenant_to_char_count["tenant_1"] == 3 + ) # 3. Test eviction of non-existent tenant is idempotent tree._reset() @@ -437,58 +452,83 @@ def test_evict_tenant_by_lru(tree: PrefixTree) -> None: # Test tree structure - validate each node # Root node - assert root.text == "" and root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert root.text == "" and root.tenant_to_last_access_time == { + "tenant_1": 1, + "tenant_2": 3, + } assert "h" in root.edge_label_to_child # Hello node hello_node: Node = root.edge_label_to_child["h"] - assert hello_node.text == "hello" and hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "w" in hello_node.edge_label_to_child and "t" in hello_node.edge_label_to_child + assert hello_node.text == "hello" and hello_node.tenant_to_last_access_time == { + "tenant_1": 1, + "tenant_2": 3, + } + assert ( + "w" in hello_node.edge_label_to_child and "t" in hello_node.edge_label_to_child + ) # World node world_node: Node = hello_node.edge_label_to_child["w"] - assert world_node.text == "world" and world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert world_node.text == "world" and world_node.tenant_to_last_access_time == { + "tenant_1": 1 + } assert len(world_node.edge_label_to_child) == 0 # Th node th_node: Node = hello_node.edge_label_to_child["t"] - assert th_node.text == "th" and th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert th_node.text == "th" and th_node.tenant_to_last_access_time == { + "tenant_2": 3 + } assert "e" in th_node.edge_label_to_child and "o" in th_node.edge_label_to_child # Ere node ere_node: Node = th_node.edge_label_to_child["e"] - assert ere_node.text == "ere" and ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert ere_node.text == "ere" and ere_node.tenant_to_last_access_time == { + "tenant_2": 2 + } assert len(ere_node.edge_label_to_child) == 0 # Omas node omas_node: Node = th_node.edge_label_to_child["o"] - assert omas_node.text == "omas" and omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert omas_node.text == "omas" and omas_node.tenant_to_last_access_time == { + "tenant_2": 3 + } assert len(omas_node.edge_label_to_child) == 0 # Test PrefixTree instance variables assert set(tree.tenant_to_nodes.keys()) == {"tenant_1", "tenant_2"} - + # Test tenant_to_nodes (check by text) - tenant1_nodes_texts: Set[str] = {node.text for node in tree.tenant_to_nodes["tenant_1"]} + tenant1_nodes_texts: Set[str] = { + node.text for node in tree.tenant_to_nodes["tenant_1"] + } assert tenant1_nodes_texts == {"", "hello", "world"} - tenant2_nodes_texts: Set[str] = {node.text for node in tree.tenant_to_nodes["tenant_2"]} + tenant2_nodes_texts: Set[str] = { + node.text for node in tree.tenant_to_nodes["tenant_2"] + } assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} # Test tenant_to_char_count # Before evictions - assert tree.tenant_to_char_count["tenant_1"] == 10 and tree.tenant_to_char_count["tenant_2"] == 14 - + assert ( + tree.tenant_to_char_count["tenant_1"] == 10 + and tree.tenant_to_char_count["tenant_2"] == 14 + ) + # After evicting tenant_1 with min_remove_size=1 # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) assert evicted_count == 10 and tree.tenant_to_char_count["tenant_1"] == 0 - + # After evicting tenant_2 with min_remove_size=1 # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) - assert evicted_count == 3 and tree.tenant_to_char_count["tenant_2"] == 11 # 14 - 3 = 11 - + assert ( + evicted_count == 3 and tree.tenant_to_char_count["tenant_2"] == 11 + ) # 14 - 3 = 11 + # After evicting tenant_2 again with min_remove_size=1 # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) From c0bb33b43f82f2a2726f0b72568aaa597c0da479 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Wed, 7 May 2025 14:37:56 -0700 Subject: [PATCH 13/15] Edit comments Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 37 ++++++++++--------- 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 9a026dfebe25e..6fa94a0ac842c 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -38,10 +38,10 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: str, Node ] = {} # Maps first character to child node self.tenant_to_last_access_time: Dict[ - str, int + str, float ] = ( {} - ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in milliseconds) + ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in seconds) class TimestampedNode: @@ -116,7 +116,7 @@ def __init__(self) -> None: str, int ] = ( {} - ) # Tracks total character count per tenant. Used by the client to determine which tenant to evict, and by how much. + ) # Tracks total character count per tenant. Used by the replica scheduler to determine which tenant to evict, and by how much. self.tenant_to_nodes: Dict[ str, Set[Node] ] = ( @@ -169,14 +169,6 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ Remove a tenant from a single node. - This function expects valid input where: - - tenant exists in self.tenant_to_nodes - - tenant exists in node.tenant_to_last_access_time - - node exists in self.tenant_to_nodes[tenant] - - These preconditions are guaranteed to be satisfied if the user is using the public methods of this class. - They may be violated if the user manipulates the internal state of the tree directly. - Args: tenant: Tenant to remove node: Node to remove tenant from @@ -291,6 +283,7 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: remaining_text: str = matched_node.text[shared_count:] # Create new intermediate node + # Note that we don't update new_parent.tenant_to_last_access_time yet; it will be done at the beginning of the next iteration. new_parent: Node = Node(text=matched_text, parent=curr_node) new_parent.tenant_to_last_access_time = ( matched_node.tenant_to_last_access_time.copy() @@ -327,10 +320,14 @@ def prefix_match( available_tenants: List of tenants to match against (or None for all) Returns: - Tuple of (matched_text, matched_tenants) - - If the list of available tenants doesn't match any tenants in the tree: returns ("", None) - - When no prefix match is found (does not traverse further than the root node): returns ("", list of available tenants) - - When a prefix match is found: returns (matched_prefix, list of tenants that own the matched node) + Tuple of (matched_text, matched_tenants): + If the list of available tenants doesn't match any tenants in the tree: returns ("", None) + When no prefix match is found (does not traverse further than the root node): returns ("", list of available tenants) + When a prefix match is found: returns (matched_prefix, list of tenants that own the matched node) + + Note: + A tenant is unable to be returned by prefix_match until it has inserted text into the tree, even if _add_tenant is called. + The replica scheduler is responsible for inserting text into new replicas; it should not only rely on prefix_match to select replicas. """ if available_tenants: # Filter available_tenants to only include those in the tree @@ -388,6 +385,7 @@ def prefix_match( def remove_tenant(self, tenant: str) -> int: """ Remove a tenant and all its nodes from the tree. + Time complexity: O(n) where n is the number of nodes owned by the tenant. Args: tenant: Tenant to remove @@ -412,6 +410,7 @@ def remove_tenant(self, tenant: str) -> int: def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ Evict least recently used nodes for a tenant until minimum size is freed. + Time complexity: O(n + m log n) where n is the number of nodes owned by the tenant, and m is the number of nodes removed. Args: tenant: The tenant to evict nodes from @@ -423,7 +422,7 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: Note: - All nodes with the same oldest access time are removed together to maintain tree integrity, even if only removing a subset of them satisfies the min_remove_size. - This behavior is expected in the case when an input was split into multiple nodes by a different tenant (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 2)). - because there is no reason to only remove "world" from tenant 1. So we remove the "chain" of "hello" and "world" from tenant 1. + because "hello" and "world" were inserted as a package, and so should be removed as a package. - However, if two inputs happen to be inserted at the same time (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 1)), then both "chains" will be removed by our method. This may not reflect the actual KV cache eviction policy. - For more predictable eviction, use unique timestamps for each insertion. @@ -450,7 +449,7 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: for node in self.tenant_to_nodes[tenant]: access_time = node.tenant_to_last_access_time[tenant] nodes_by_access_time.append(TimestampedNode(node, access_time)) - heapq.heapify(nodes_by_access_time) + heapq.heapify(nodes_by_access_time) # O(n) # Remove nodes until we've freed enough characters while total_chars_removed < min_remove_size and nodes_by_access_time: @@ -463,7 +462,9 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: nodes_by_access_time and nodes_by_access_time[0].time_sec == oldest_access_time ): - node_to_remove = heapq.heappop(nodes_by_access_time).node + node_to_remove = heapq.heappop( + nodes_by_access_time + ).node # O(log n) total_chars_removed += self._remove_tenant_single_node( tenant, node_to_remove ) From 42d6938ecbbd2600abe8ec659570742de057b5f5 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Thu, 8 May 2025 16:58:18 -0700 Subject: [PATCH 14/15] Doubly linked list instead of min-heap, don't have insert return Node because of pickle error Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 150 ++++++++++++++---- .../serve/cpu/deployments/test_prefix_tree.py | 2 +- 2 files changed, 120 insertions(+), 32 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 6fa94a0ac842c..9a76f5d151ffe 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -1,6 +1,5 @@ from __future__ import annotations -import heapq import logging import os from threading import RLock @@ -42,6 +41,13 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: ] = ( {} ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in seconds) + # Doubly linked list pointers for LRU tracking per tenant + self.tenant_to_older_node: Dict[ + str, Optional[Node] + ] = {} # Points to the less recently used node (toward tail for eviction) + self.tenant_to_newer_node: Dict[ + str, Optional[Node] + ] = {} # Points to the more recently used node (toward head for retention) class TimestampedNode: @@ -123,6 +129,10 @@ def __init__(self) -> None: {} ) # Maps tenant to set of nodes. Used for O(1) testing if a node belongs to a tenant. The keys are the active tenants in the tree. + # LRU tracking - head is the most recently used node, tail is the least recently used + self.tenant_to_lru_head: Dict[str, Optional[Node]] = {} + self.tenant_to_lru_tail: Dict[str, Optional[Node]] = {} + @staticmethod def _shared_prefix_count(a: str, b: str) -> int: """ @@ -147,6 +157,8 @@ def _reset(self) -> None: self.root = Node() self.tenant_to_char_count = {} self.tenant_to_nodes = {} + self.tenant_to_lru_head = {} + self.tenant_to_lru_tail = {} def _add_tenant(self, tenant: str) -> None: """ @@ -164,6 +176,57 @@ def _add_tenant(self, tenant: str) -> None: self.tenant_to_char_count[tenant] = 0 self.tenant_to_nodes[tenant] = set() + self.tenant_to_lru_head[tenant] = None + self.tenant_to_lru_tail[tenant] = None + + def _move_node_to_head(self, node: Node, tenant: str) -> None: + """ + Move a node to the head of the tenant's LRU list. + + Args: + node: Node to move + tenant: Tenant that owns the node + """ + # If this is the first node, initialize the LRU list + if self.tenant_to_lru_head.get(tenant) is None: + self.tenant_to_lru_head[tenant] = node + self.tenant_to_lru_tail[tenant] = node + node.tenant_to_older_node[tenant] = None + node.tenant_to_newer_node[tenant] = None + return + + # If node is already the head, nothing to do + if node == self.tenant_to_lru_head[tenant]: + return + + # If node is already in the list, remove it + if tenant in node.tenant_to_older_node or tenant in node.tenant_to_newer_node: + # Connect older and newer nodes directly (skip this node) + older = node.tenant_to_older_node.get( + tenant + ) # Less recently used (toward tail) + newer = node.tenant_to_newer_node.get( + tenant + ) # More recently used (toward head) + + if older: + older.tenant_to_newer_node[tenant] = newer + + if newer: + newer.tenant_to_older_node[tenant] = older + + # If this is the tail, update tail pointer + if node == self.tenant_to_lru_tail[tenant]: + self.tenant_to_lru_tail[tenant] = newer + + # Place at head of list + current_head = self.tenant_to_lru_head[tenant] + node.tenant_to_newer_node[tenant] = None # Head has no newer node + node.tenant_to_older_node[ + tenant + ] = current_head # Old head becomes older than new head + current_head.tenant_to_newer_node[tenant] = node # Connect old head to new head + self.tenant_to_lru_head[tenant] = node # Update head pointer def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ @@ -196,6 +259,29 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: self.tenant_to_nodes[tenant].remove(node) node.tenant_to_last_access_time.pop(tenant, None) + # Remove from LRU list + older = node.tenant_to_older_node.get( + tenant + ) # Less recently used (toward tail) + newer = node.tenant_to_newer_node.get( + tenant + ) # More recently used (toward head) + + if older: + older.tenant_to_newer_node[tenant] = newer + + if newer: + newer.tenant_to_older_node[tenant] = older + + # Update head/tail pointers if necessary + if node == self.tenant_to_lru_head[tenant]: + self.tenant_to_lru_head[tenant] = older # Older becomes new head + if node == self.tenant_to_lru_tail[tenant]: + self.tenant_to_lru_tail[tenant] = newer # Newer becomes new tail + + node.tenant_to_older_node.pop(tenant, None) + node.tenant_to_newer_node.pop(tenant, None) + # Clean up empty nodes if not node.tenant_to_last_access_time and node.parent: if ( @@ -205,7 +291,7 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len - def insert(self, text: str, tenant: str, time_sec: float) -> Node: + def insert(self, text: str, tenant: str, time_sec: float) -> None: """ Insert text into tree for a specific tenant. @@ -216,9 +302,6 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: tenant: Tenant time_sec: Current timestamp in seconds - Returns: - The node that was inserted or updated - Loop structure: 1. At the start of each iteration, curr_node is a node we potentially update. e.g. Update node.tenant_to_last_access_time[tenant], self.tenant_to_char_count, @@ -246,6 +329,7 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: self.tenant_to_nodes[tenant].add(curr_node) curr_node.tenant_to_last_access_time[tenant] = time_sec + self._move_node_to_head(curr_node, tenant) if i == len(text): break @@ -290,6 +374,11 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: ) for existing_tenant in new_parent.tenant_to_last_access_time: self.tenant_to_nodes[existing_tenant].add(new_parent) + # Initialize LRU list pointers + new_parent.tenant_to_older_node[existing_tenant] = None + new_parent.tenant_to_newer_node[existing_tenant] = None + # Move to head of LRU list for each tenant + self._move_node_to_head(new_parent, existing_tenant) # Update existing matched node matched_node.text = remaining_text @@ -307,8 +396,6 @@ def insert(self, text: str, tenant: str, time_sec: float) -> Node: curr_node = matched_node i += shared_count - return curr_node - def prefix_match( self, text: str, available_tenants: Optional[List[str]] = None ) -> Tuple[str, Optional[List[str]]]: @@ -404,13 +491,15 @@ def remove_tenant(self, tenant: str) -> int: self.tenant_to_nodes.pop(tenant, None) self.tenant_to_char_count.pop(tenant, None) + self.tenant_to_lru_head.pop(tenant, None) + self.tenant_to_lru_tail.pop(tenant, None) return total_chars_removed def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: """ Evict least recently used nodes for a tenant until minimum size is freed. - Time complexity: O(n + m log n) where n is the number of nodes owned by the tenant, and m is the number of nodes removed. + Time complexity: O(m) where m is the number of nodes removed. Args: tenant: The tenant to evict nodes from @@ -443,31 +532,30 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: total_chars_removed: int = 0 - # Create a min-heap of nodes ordered by access time - # Each entry is a TimestampedNode(node, access_time) object, which has a __lt__ method that is used by heapq. - nodes_by_access_time = [] - for node in self.tenant_to_nodes[tenant]: - access_time = node.tenant_to_last_access_time[tenant] - nodes_by_access_time.append(TimestampedNode(node, access_time)) - heapq.heapify(nodes_by_access_time) # O(n) - - # Remove nodes until we've freed enough characters - while total_chars_removed < min_remove_size and nodes_by_access_time: - # Get the oldest (minimum) access time from the top of the heap - oldest_access_time = nodes_by_access_time[0].time_sec - - # Remove ALL nodes with this same access time to maintain tree consistency - # (partial removals could break prefix relationships) + # Start removing from the tail (least recently used) + tail = self.tenant_to_lru_tail.get(tenant) + + # Continue until we've freed enough space or run out of nodes + while total_chars_removed < min_remove_size and tail: + # Get the current timestamp to remove all nodes with this timestamp + current_timestamp = tail.tenant_to_last_access_time[tenant] + nodes_with_same_timestamp = [] + + # Collect all nodes with the same timestamp (guaranteed to be contiguous in our LRU list) + current = tail while ( - nodes_by_access_time - and nodes_by_access_time[0].time_sec == oldest_access_time + current + and current.tenant_to_last_access_time[tenant] == current_timestamp ): - node_to_remove = heapq.heappop( - nodes_by_access_time - ).node # O(log n) - total_chars_removed += self._remove_tenant_single_node( - tenant, node_to_remove - ) + nodes_with_same_timestamp.append(current) + current = current.tenant_to_newer_node.get(tenant) + + # Set the new tail to continue from for the next iteration (if needed) + tail = current + + # Remove all collected nodes with the same timestamp + for node in nodes_with_same_timestamp: + total_chars_removed += self._remove_tenant_single_node(tenant, node) return total_chars_removed diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index 0e39859fa85c8..cebfc847923ad 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -276,7 +276,7 @@ def test__remove_tenant_single_node(tree: PrefixTree) -> None: # 1. Test removing a single node tree._reset() tree.insert("hello", "tenant_1", 1) - h_node: Node = tree.insert("hello", "tenant_1", 1) + h_node: Node = tree.root.edge_label_to_child["h"] removed: int = tree._remove_tenant_single_node("tenant_1", h_node) assert removed == 5 From 0e0bb821db62c86846c046e19f7fe06184c19c49 Mon Sep 17 00:00:00 2001 From: Justin Ji Date: Mon, 12 May 2025 12:21:40 -0700 Subject: [PATCH 15/15] Fix LRU linked list implementation and clean up tests Signed-off-by: Justin Ji --- .../prefix_aware/prefix_tree.py | 378 +++--- .../serve/cpu/deployments/test_prefix_tree.py | 1196 +++++++++-------- 2 files changed, 822 insertions(+), 752 deletions(-) diff --git a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py index 9a76f5d151ffe..3cba4f5223d89 100644 --- a/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py +++ b/python/ray/llm/_internal/serve/replica_scheduler/prefix_aware/prefix_tree.py @@ -3,7 +3,7 @@ import logging import os from threading import RLock -from typing import Dict, List, Optional, Set, Tuple, Any +from typing import Any, Dict, List, Optional, Tuple import ray @@ -32,52 +32,17 @@ def __init__(self, text: str = "", parent: Optional[Node] = None) -> None: parent: The parent node of this node """ self.text: str = text - self.parent: Optional[Node] = parent # The parent node of this node - self.edge_label_to_child: Dict[ - str, Node - ] = {} # Maps first character to child node - self.tenant_to_last_access_time: Dict[ - str, float - ] = ( - {} - ) # For each tenant that has inserted text matching this node, maps tenant to the last access timestamp (in seconds) - # Doubly linked list pointers for LRU tracking per tenant - self.tenant_to_older_node: Dict[ - str, Optional[Node] - ] = {} # Points to the less recently used node (toward tail for eviction) - self.tenant_to_newer_node: Dict[ - str, Optional[Node] - ] = {} # Points to the more recently used node (toward head for retention) - - -class TimestampedNode: - """ - Wrapper class for storing nodes in a min-heap, ordered by timestamp. - Used for efficient LRU eviction of nodes. - """ - - def __init__(self, node: Node, time_sec: float) -> None: - """ - Initialize a heap node for efficient LRU eviction of nodes. - - Args: - node: The prefix tree node this heap node refers to - time_sec: The timestamp this heap uses to order nodes - """ - self.node = node - self.time_sec = time_sec + self.parent: Optional[Node] = parent - def __lt__(self, other: TimestampedNode) -> bool: - """ - Compare heap nodes based on timestamp. - - Args: - other: Another TimestampedNode to compare with - - Returns: - True if this node's timestamp is earlier than the other's - """ - return self.time_sec < other.time_sec + # Maps first character to child node + self.edge_label_to_child: Dict[str, Node] = {} + # For each tenant that has inserted text matching this node, track its last access timestamp (in seconds) + self.tenant_to_last_access_time: Dict[str, float] = {} + # Doubly linked list pointers for LRU tracking per tenant + # Points to the less recently used node (toward tail) + self.tenant_to_older_node: Dict[str, Optional[Node]] = {} + # Points to the more recently used node (toward head) + self.tenant_to_newer_node: Dict[str, Optional[Node]] = {} class PrefixTree: @@ -111,26 +76,21 @@ class PrefixTree: PrefixTree instance variables: self.tenant_to_char_count = {"tenant_1": 10, "tenant_2": 14} - self.tenant_to_nodes = {"tenant_1": {root, Node("hello"), Node("world")}, "tenant_2": {root, Node("hello"), Node("th"), Node("ere"), Node("omas")}} + self.tenant_to_lru_tail = {"tenant_1": Node("world"), "tenant_2": Node("ere")} """ def __init__(self) -> None: """Initialize an empty prefix tree.""" self.lock: RLock = RLock() + + # Root is always the head of the LRU list for each tenant. self.root: Node = Node() - self.tenant_to_char_count: Dict[ - str, int - ] = ( - {} - ) # Tracks total character count per tenant. Used by the replica scheduler to determine which tenant to evict, and by how much. - self.tenant_to_nodes: Dict[ - str, Set[Node] - ] = ( - {} - ) # Maps tenant to set of nodes. Used for O(1) testing if a node belongs to a tenant. The keys are the active tenants in the tree. - - # LRU tracking - head is the most recently used node, tail is the least recently used - self.tenant_to_lru_head: Dict[str, Optional[Node]] = {} + + # Tracks total character count per tenant. Can be used by the replica scheduler to determine which tenant to evict, and by how much. + # Also uses the keys to track the active tenants in the tree. + self.tenant_to_char_count: Dict[str, int] = {} + + # LRU tracking - root is always the head, tail is the least recently used. self.tenant_to_lru_tail: Dict[str, Optional[Node]] = {} @staticmethod @@ -147,18 +107,18 @@ def _shared_prefix_count(a: str, b: str) -> int: """ return len(os.path.commonprefix([a, b])) - def _reset(self) -> None: + def _get_lru_chain(self, tenant: str) -> List[Node]: """ - Reset the tree to an empty state. - + Get the LRU chain for a given tenant by traversing from the root to the oldest node. Note: This method is intended to be used only in tests. """ with self.lock: - self.root = Node() - self.tenant_to_char_count = {} - self.tenant_to_nodes = {} - self.tenant_to_lru_head = {} - self.tenant_to_lru_tail = {} + nodes = [] + current_node = self.root + while current_node: + nodes.append(current_node) + current_node = current_node.tenant_to_older_node.get(tenant) + return nodes def _add_tenant(self, tenant: str) -> None: """ @@ -170,63 +130,78 @@ def _add_tenant(self, tenant: str) -> None: tenant: Tenant to add """ with self.lock: - if tenant in self.tenant_to_nodes: + if tenant in self.tenant_to_char_count: logger.warning(f"Tenant '{tenant}' already exists. No action taken.") return self.tenant_to_char_count[tenant] = 0 - self.tenant_to_nodes[tenant] = set() - self.tenant_to_lru_head[tenant] = None - self.tenant_to_lru_tail[tenant] = None - - def _move_node_to_head(self, node: Node, tenant: str) -> None: + self.tenant_to_lru_tail[tenant] = self.root + + # Initialize the root node as the head of the LRU list for this tenant + self.root.tenant_to_newer_node[tenant] = None + self.root.tenant_to_older_node[tenant] = None + + def _insert_node_into_linked_list( + self, + node: Node, + newer_neighbor: Optional[Node], + older_neighbor: Optional[Node], + tenant: str, + ) -> None: """ - Move a node to the head of the tenant's LRU list. + Insert a node into the LRU list between two neighbors. Updates the neighbors' pointers and the tail pointer, if that changes. + """ + with self.lock: + if tenant not in self.tenant_to_char_count: + logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + return - Args: - node: Node to move - tenant: Tenant that owns the node + # Skip if node is the root + if node == self.root: + return + + node.tenant_to_newer_node[tenant] = newer_neighbor + node.tenant_to_older_node[tenant] = older_neighbor + + if newer_neighbor: + newer_neighbor.tenant_to_older_node[tenant] = node + + if older_neighbor: + older_neighbor.tenant_to_newer_node[tenant] = node + + if self.tenant_to_lru_tail[tenant] == newer_neighbor: + self.tenant_to_lru_tail[tenant] = node + + def _remove_node_from_linked_list(self, node: Node, tenant: str) -> None: """ - # If this is the first node, initialize the LRU list - if self.tenant_to_lru_head.get(tenant) is None: - self.tenant_to_lru_head[tenant] = node - self.tenant_to_lru_tail[tenant] = node - node.tenant_to_older_node[tenant] = None - node.tenant_to_newer_node[tenant] = None - return - - # If node is already the head, nothing to do - if node == self.tenant_to_lru_head[tenant]: - return - - # If node is already in the list, remove it - if tenant in node.tenant_to_older_node or tenant in node.tenant_to_newer_node: - # Connect older and newer nodes directly (skip this node) - older = node.tenant_to_older_node.get( - tenant - ) # Less recently used (toward tail) - newer = node.tenant_to_newer_node.get( - tenant - ) # More recently used (toward head) + Remove a node from the LRU list. Updates the neighbors' pointers and the tail pointer, if that changes. + """ + with self.lock: + if tenant not in self.tenant_to_char_count: + logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") + return - if older: - older.tenant_to_newer_node[tenant] = newer + # Skip if node is the root + if node == self.root: + return - if newer: - newer.tenant_to_older_node[tenant] = older + # Connect older and newer neighbors + older_neighbor = node.tenant_to_older_node.get(tenant) + newer_neighbor = node.tenant_to_newer_node.get(tenant) - # If this is the tail, update tail pointer - if node == self.tenant_to_lru_tail[tenant]: - self.tenant_to_lru_tail[tenant] = newer + if older_neighbor: + older_neighbor.tenant_to_newer_node[tenant] = newer_neighbor - # Place at head of list - current_head = self.tenant_to_lru_head[tenant] - node.tenant_to_newer_node[tenant] = None # Head has no newer node - node.tenant_to_older_node[ - tenant - ] = current_head # Old head becomes older than new head - current_head.tenant_to_newer_node[tenant] = node # Connect old head to new head - self.tenant_to_lru_head[tenant] = node # Update head pointer + if newer_neighbor: + newer_neighbor.tenant_to_older_node[tenant] = older_neighbor + + # Update tail pointer if necessary + if self.tenant_to_lru_tail[tenant] == node: + self.tenant_to_lru_tail[tenant] = newer_neighbor + + # Remove node from list + node.tenant_to_newer_node.pop(tenant, None) + node.tenant_to_older_node.pop(tenant, None) def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: """ @@ -240,7 +215,7 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: Number of characters removed (0 if preconditions not met) """ with self.lock: - if tenant not in self.tenant_to_nodes: + if tenant not in self.tenant_to_char_count: logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") return 0 if tenant not in node.tenant_to_last_access_time: @@ -248,39 +223,12 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: f"Tenant '{tenant}' does not have node '{node.text}'. No action taken." ) return 0 - if node not in self.tenant_to_nodes[tenant]: - logger.warning( - f"Node '{node.text}' does not belong to tenant '{tenant}'. No action taken." - ) - return 0 removed_chars_len: int = len(node.text) self.tenant_to_char_count[tenant] -= removed_chars_len - self.tenant_to_nodes[tenant].remove(node) node.tenant_to_last_access_time.pop(tenant, None) - # Remove from LRU list - older = node.tenant_to_older_node.get( - tenant - ) # Less recently used (toward tail) - newer = node.tenant_to_newer_node.get( - tenant - ) # More recently used (toward head) - - if older: - older.tenant_to_newer_node[tenant] = newer - - if newer: - newer.tenant_to_older_node[tenant] = older - - # Update head/tail pointers if necessary - if node == self.tenant_to_lru_head[tenant]: - self.tenant_to_lru_head[tenant] = older # Older becomes new head - if node == self.tenant_to_lru_tail[tenant]: - self.tenant_to_lru_tail[tenant] = newer # Newer becomes new tail - - node.tenant_to_older_node.pop(tenant, None) - node.tenant_to_newer_node.pop(tenant, None) + self._remove_node_from_linked_list(node, tenant) # Clean up empty nodes if not node.tenant_to_last_access_time and node.parent: @@ -291,21 +239,20 @@ def _remove_tenant_single_node(self, tenant: str, node: Node) -> int: return removed_chars_len - def insert(self, text: str, tenant: str, time_sec: float) -> None: + def insert(self, text: str, tenant: str, time_s: float) -> None: """ Insert text into tree for a specific tenant. - If the tenant doesn't exist, it will be automatically added. + If the tenant doesn't already exist in the tree, it will be automatically added. Args: text: Text to insert tenant: Tenant - time_sec: Current timestamp in seconds + time_s: Current timestamp in seconds Loop structure: - 1. At the start of each iteration, curr_node is a node we potentially update. - e.g. Update node.tenant_to_last_access_time[tenant], self.tenant_to_char_count, - self.tenant_to_nodes + 1. We update the current node at the start of each iteration of the while loop. + This includes updating tenant_to_char_count and tenant_to_last_access_time, and moving the node to the front of the LRU list. 2. Each iteration then either: a. Breaks (if we've processed the entire string). b. Processes the next segment of text by: @@ -315,22 +262,26 @@ def insert(self, text: str, tenant: str, time_sec: float) -> None: b. If they fully match, traverse into the child node. """ with self.lock: - if tenant not in self.tenant_to_nodes: + if tenant not in self.tenant_to_char_count: self._add_tenant(tenant) curr_node: Node = self.root i: int = 0 - while i <= len(text): # Invariant at beginning of each iteration: assume curr_node has not been visited by tenant yet. # Update tenant info for current node. if tenant not in curr_node.tenant_to_last_access_time: self.tenant_to_char_count[tenant] += len(curr_node.text) - self.tenant_to_nodes[tenant].add(curr_node) - - curr_node.tenant_to_last_access_time[tenant] = time_sec - self._move_node_to_head(curr_node, tenant) + curr_node.tenant_to_last_access_time[tenant] = time_s + if curr_node != self.root: + self._remove_node_from_linked_list(curr_node, tenant) + self._insert_node_into_linked_list( + curr_node, + self.root, + self.root.tenant_to_older_node.get(tenant), + tenant, + ) if i == len(text): break @@ -342,13 +293,16 @@ def insert(self, text: str, tenant: str, time_sec: float) -> None: # e.g. curr_node.edge_label_to_child = {}, curr_text = "hello" -> curr_node.edge_label_to_child = {"h": Node("hello")} new_node: Node = Node(text=curr_text, parent=curr_node) curr_node.edge_label_to_child[first_char] = new_node + # Add the node to the back of the LRU list; it will be moved to the front in the next iteration. + self._insert_node_into_linked_list( + new_node, self.tenant_to_lru_tail[tenant], None, tenant + ) # Match found, check if we need to split matched_node: Node = curr_node.edge_label_to_child[first_char] shared_count: int = self._shared_prefix_count( matched_node.text, curr_text ) - if shared_count < len(matched_node.text): # Partial match, split node at matched point # Example: @@ -360,6 +314,8 @@ def insert(self, text: str, tenant: str, time_sec: float) -> None: ### curr_node.edge_label_to_child = {"h": Node("hello", edge_label_to_child = {"w": Node("world")})} ### parent_node = Node("hello"), matched_node = Node("world") ### Copy matched_node.tenant_to_last_access_time to parent_node.tenant_to_last_access_time + ### Insert parent_node into the back of the LRU list; it will be moved to the front in the next iteration. (for the current tenant) + ### Insert parent_node between matched_node and matched_node's newer neighbor (for all other tenants) ### (new) curr_text = "there", (new) curr_node = parent_node ### Continue adding "there" to tree in next iteration @@ -372,13 +328,19 @@ def insert(self, text: str, tenant: str, time_sec: float) -> None: new_parent.tenant_to_last_access_time = ( matched_node.tenant_to_last_access_time.copy() ) + # Insert new_parent into the back of the LRU list; it will be moved to the front in the next iteration. (for the current tenant) + self._insert_node_into_linked_list( + new_parent, self.tenant_to_lru_tail[tenant], None, tenant + ) + # Insert new_parent between matched_node and matched_node's newer neighbor (for all other tenants) for existing_tenant in new_parent.tenant_to_last_access_time: - self.tenant_to_nodes[existing_tenant].add(new_parent) - # Initialize LRU list pointers - new_parent.tenant_to_older_node[existing_tenant] = None - new_parent.tenant_to_newer_node[existing_tenant] = None - # Move to head of LRU list for each tenant - self._move_node_to_head(new_parent, existing_tenant) + if existing_tenant != tenant: + self._insert_node_into_linked_list( + new_parent, + matched_node.tenant_to_newer_node.get(existing_tenant), + matched_node, + existing_tenant, + ) # Update existing matched node matched_node.text = remaining_text @@ -416,17 +378,19 @@ def prefix_match( A tenant is unable to be returned by prefix_match until it has inserted text into the tree, even if _add_tenant is called. The replica scheduler is responsible for inserting text into new replicas; it should not only rely on prefix_match to select replicas. """ - if available_tenants: - # Filter available_tenants to only include those in the tree - available_tenants = [ - tenant for tenant in available_tenants if tenant in self.tenant_to_nodes - ] - if not available_tenants: - return "", None - else: - available_tenants = list(self.tenant_to_nodes.keys()) - with self.lock: + if available_tenants: + # Filter available_tenants to only include those in the tree + available_tenants = [ + tenant + for tenant in available_tenants + if tenant in self.tenant_to_char_count + ] + if not available_tenants: + return "", None + else: + available_tenants = list(self.tenant_to_char_count.keys()) + curr_node: Node = self.root i: int = 0 text_len: int = len(text) @@ -481,17 +445,23 @@ def remove_tenant(self, tenant: str) -> int: Number of characters removed (0 if tenant doesn't exist) """ with self.lock: - if tenant not in self.tenant_to_nodes: + if tenant not in self.tenant_to_char_count: logger.warning(f"Tenant '{tenant}' does not exist. No action taken.") return 0 total_chars_removed: int = 0 - for node in self.tenant_to_nodes[tenant].copy(): - total_chars_removed += self._remove_tenant_single_node(tenant, node) - self.tenant_to_nodes.pop(tenant, None) + # Start from the tail and remove all nodes + current_tail = self.tenant_to_lru_tail.get(tenant) + while current_tail: + newer_neighbor = current_tail.tenant_to_newer_node.get(tenant) + total_chars_removed += self._remove_tenant_single_node( + tenant, current_tail + ) + current_tail = newer_neighbor + + # Clean up tenant references self.tenant_to_char_count.pop(tenant, None) - self.tenant_to_lru_head.pop(tenant, None) self.tenant_to_lru_tail.pop(tenant, None) return total_chars_removed @@ -510,16 +480,13 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: Note: - All nodes with the same oldest access time are removed together to maintain tree integrity, even if only removing a subset of them satisfies the min_remove_size. - - This behavior is expected in the case when an input was split into multiple nodes by a different tenant (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 2)). - because "hello" and "world" were inserted as a package, and so should be removed as a package. - - However, if two inputs happen to be inserted at the same time (e.g. insert("helloworld", "tenant_1", 1) and insert("hellothere", "tenant_2", 1)), - then both "chains" will be removed by our method. This may not reflect the actual KV cache eviction policy. - For more predictable eviction, use unique timestamps for each insertion. + - The root node is never evicted as it serves as the permanent head of the LRU list. """ with self.lock: - if tenant not in self.tenant_to_nodes or not self.tenant_to_nodes[tenant]: + if tenant not in self.tenant_to_char_count: logger.warning( - f"Cannot evict tenant '{tenant}': tenant does not exist or has no nodes. No action taken." + f"Cannot evict tenant '{tenant}': tenant does not exist. No action taken." ) return 0 @@ -533,29 +500,28 @@ def evict_tenant_by_lru(self, tenant: str, min_remove_size: int) -> int: total_chars_removed: int = 0 # Start removing from the tail (least recently used) - tail = self.tenant_to_lru_tail.get(tenant) + current_tail = self.tenant_to_lru_tail.get(tenant) # Continue until we've freed enough space or run out of nodes - while total_chars_removed < min_remove_size and tail: + while total_chars_removed < min_remove_size and current_tail: + # Stop if we've reached the root - the root is never evicted + if current_tail == self.root: + break + # Get the current timestamp to remove all nodes with this timestamp - current_timestamp = tail.tenant_to_last_access_time[tenant] - nodes_with_same_timestamp = [] + current_timestamp = current_tail.tenant_to_last_access_time[tenant] # Collect all nodes with the same timestamp (guaranteed to be contiguous in our LRU list) - current = tail while ( - current - and current.tenant_to_last_access_time[tenant] == current_timestamp + current_tail != self.root # Never include the root + and current_tail.tenant_to_last_access_time[tenant] + == current_timestamp ): - nodes_with_same_timestamp.append(current) - current = current.tenant_to_newer_node.get(tenant) - - # Set the new tail to continue from for the next iteration (if needed) - tail = current - - # Remove all collected nodes with the same timestamp - for node in nodes_with_same_timestamp: - total_chars_removed += self._remove_tenant_single_node(tenant, node) + newer_neighbor = current_tail.tenant_to_newer_node.get(tenant) + total_chars_removed += self._remove_tenant_single_node( + tenant, current_tail + ) + current_tail = newer_neighbor return total_chars_removed @@ -579,17 +545,9 @@ def get_smallest_tenant(self) -> Optional[str]: @ray.remote class PrefixTreeActor(PrefixTree): - def _to_dict(self) -> Dict[str, Any]: + def getattr(self, attribute: str) -> Any: """ - Convert tree to dictionary for serialization. - - Returns: - Dictionary representation of the tree - + Get an attribute of the PrefixTree. Note: This method is intended to be used only in tests. """ - return { - "root": self.root, - "tenant_to_char_count": self.tenant_to_char_count, - "tenant_to_nodes": self.tenant_to_nodes, - } + return getattr(self, attribute) diff --git a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py index cebfc847923ad..1a84364896669 100644 --- a/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py +++ b/python/ray/llm/tests/serve/cpu/deployments/test_prefix_tree.py @@ -1,6 +1,6 @@ import pytest import ray -from typing import Set, List, Dict, Optional +from typing import Set, List from ray.llm._internal.serve.replica_scheduler.prefix_aware.prefix_tree import ( PrefixTree, @@ -12,556 +12,668 @@ # Fixtures @pytest.fixture def tree() -> PrefixTree: - """Create a fresh PrefixTree instance for each test.""" + """Create a fresh PrefixTree instance for each local test.""" return PrefixTree() -@pytest.fixture(scope="module") +@pytest.fixture def tree_actor(): - """Create a fresh PrefixTreeActor instance for each test.""" - tree_actor = PrefixTreeActor.remote() - return tree_actor + """Create a fresh PrefixTreeActor instance for each ray.remote test.""" + return PrefixTreeActor.remote() + + +# Helper to get LRU chain texts +def get_lru_texts_from_tree(tree: PrefixTree, tenant_id: str) -> List[str]: + """Gets LRU chain texts directly from a PrefixTree instance.""" + chain = tree._get_lru_chain(tenant_id) + return [node.text for node in chain] + + +async def get_lru_texts_from_tree_actor( + tree_actor: PrefixTreeActor, tenant_id: str +) -> List[str]: + """Gets LRU chain texts from a PrefixTreeActor.""" + chain = ray.get(tree_actor._get_lru_chain.remote(tenant_id)) + return [node.text for node in chain] + + +class TestPrefixTreeInitialization: + """Tests for the PrefixTree class initialization and basic tenant management.""" + + def test_initial_state(self, tree: PrefixTree) -> None: + """Test the initial state of a new PrefixTree.""" + assert tree.tenant_to_char_count == {} + assert tree.tenant_to_lru_tail == {} + assert tree.root is not None + assert tree.root.text == "" + assert tree.root.parent is None + assert tree.root.tenant_to_last_access_time == {} + assert tree.root.edge_label_to_child == {} + + def test_add_tenant(self, tree: PrefixTree) -> None: + """Test adding a new tenant via _add_tenant.""" + tree._add_tenant("tenant_1") + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root + # _add_tenant itself doesn't update root's access time for the tenant. + assert tree.root.tenant_to_last_access_time == {} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + + def test_add_existing_tenant_noop(self, tree: PrefixTree) -> None: + """Test that adding an existing tenant via _add_tenant is a no-op.""" + tree._add_tenant("tenant_1") + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root + assert tree.root.tenant_to_last_access_time == {} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + + tree._add_tenant("tenant_1") # Add again + + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert tree.tenant_to_lru_tail.get("tenant_1") == tree.root + assert tree.root.tenant_to_last_access_time == {} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + + +class TestPrefixTreeInsert: + def test_insert_single_string(self, tree: PrefixTree) -> None: + """Test inserting a single string, which also adds a new tenant.""" + tree.insert("hello", "tenant_1", 1) + assert tree.tenant_to_char_count == {"tenant_1": 5} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello"] + + root_node = tree.root + assert root_node.tenant_to_last_access_time == {"tenant_1": 1} + assert set(root_node.edge_label_to_child.keys()) == {"h"} + + hello_node = root_node.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.parent == root_node + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1} + assert hello_node.edge_label_to_child == {} + + def test_insert_duplicate_string(self, tree: PrefixTree) -> None: + """Test inserting a duplicate string for the same tenant.""" + tree.insert("hello", "tenant_1", 1) # Initial insert + tree.insert("hello", "tenant_1", 1) # Duplicate insert with the same timestamp + + assert tree.tenant_to_char_count == {"tenant_1": 5} # Char count unchanged + assert get_lru_texts_from_tree(tree, "tenant_1") == [ + "", + "hello", + ] # LRU order same + + hello_node = tree.root.edge_label_to_child["h"] + assert tree.root.tenant_to_last_access_time == {"tenant_1": 1} + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1} + + tree.insert("hello", "tenant_1", 2) # Duplicate insert with new timestamp + + assert tree.tenant_to_char_count == {"tenant_1": 5} # Char count unchanged + assert get_lru_texts_from_tree(tree, "tenant_1") == [ + "", + "hello", + ] # LRU order same + + hello_node = tree.root.edge_label_to_child["h"] + assert tree.root.tenant_to_last_access_time == {"tenant_1": 2} + assert hello_node.tenant_to_last_access_time == {"tenant_1": 2} + + def test_insert_multiple_tenants(self, tree: PrefixTree) -> None: + """Test inserting the same string for different tenants.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("hello", "tenant_2", 2) + + assert tree.tenant_to_char_count == {"tenant_1": 5, "tenant_2": 5} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello"] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "hello"] + + hello_node = tree.root.edge_label_to_child["h"] + assert tree.root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + + def test_insert_node_split(self, tree: PrefixTree) -> None: + """Test insertion that causes an existing node to split due to differing suffixes.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) # "hello" is common prefix + + assert tree.tenant_to_char_count == {"tenant_1": 10, "tenant_2": 10} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello", "world"] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "there", "hello"] + + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + assert set(hello_node.edge_label_to_child.keys()) == {"w", "t"} + + world_node = hello_node.edge_label_to_child["w"] + assert world_node.text == "world" + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + + there_node = hello_node.edge_label_to_child["t"] + assert there_node.text == "there" + assert there_node.tenant_to_last_access_time == {"tenant_2": 2} + + def test_insert_longer_string_with_shared_prefix(self, tree: PrefixTree) -> None: + """Test inserting a longer string that shares a prefix with an existing node string.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("helloworld", "tenant_2", 2) # "hello" is prefix of "helloworld" + + assert tree.tenant_to_char_count == {"tenant_1": 5, "tenant_2": 10} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello"] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "world", "hello"] + + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + assert set(hello_node.edge_label_to_child.keys()) == {"w"} + + world_node = hello_node.edge_label_to_child["w"] + assert world_node.text == "world" + assert world_node.tenant_to_last_access_time == {"tenant_2": 2} + + # Ensure no empty non-root nodes created + empty_text_nodes: List[Node] = [] + nodes_to_check: List[Node] = [tree.root] + visited_nodes: Set[Node] = {tree.root} + while nodes_to_check: + node: Node = nodes_to_check.pop() + if node.text == "" and node != tree.root: # check for non-root empty nodes + empty_text_nodes.append(node) + for child in node.edge_label_to_child.values(): + if child not in visited_nodes: + nodes_to_check.append(child) + visited_nodes.add(child) + assert not empty_text_nodes + + def test_insert_shorter_string_with_shared_prefix(self, tree: PrefixTree) -> None: + """Test inserting a shorter string that is a prefix of an existing longer string, causing split.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert( + "hello", "tenant_2", 2 + ) # "hello" is prefix, causes "helloworld" to split + + assert tree.tenant_to_char_count == {"tenant_1": 10, "tenant_2": 5} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello", "world"] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "hello"] + + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + assert set(hello_node.edge_label_to_child.keys()) == {"w"} + + world_node = hello_node.edge_label_to_child["w"] + assert world_node.text == "world" + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + + +class TestPrefixTreeMatch: + def test_prefix_match_empty_tree(self, tree: PrefixTree) -> None: + """Test prefix_match on an empty tree returns empty string and None tenants.""" + matched_text, matched_tenants = tree.prefix_match("hello") + assert matched_text == "" + assert matched_tenants is None + + def test_prefix_match_no_match(self, tree: PrefixTree) -> None: + """Test prefix_match for a non-matching prefix returns empty string and all tenants.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("world", "tenant_2", 2) + matched_text, matched_tenants = tree.prefix_match("foobar") + assert matched_text == "" + assert matched_tenants is not None + assert sorted(matched_tenants) == sorted(["tenant_1", "tenant_2"]) + + def test_prefix_match_query_longer_than_stored_strings( + self, tree: PrefixTree + ) -> None: + """Test prefix_match where query is longer than any stored string but matches a full path.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + matched_text, matched_tenants = tree.prefix_match("hellothereextra") + assert matched_text == "hellothere" + assert matched_tenants == ["tenant_2"] + + def test_prefix_match_exact_match(self, tree: PrefixTree) -> None: + """Test prefix_match with an exact match for a single tenant.""" + tree.insert("hello", "tenant_1", 1) + matched_text, matched_tenants = tree.prefix_match("hello") + assert matched_text == "hello" + assert matched_tenants == ["tenant_1"] + + def test_prefix_match_partial_match(self, tree: PrefixTree) -> None: + """Test prefix_match with a partial query matching the longest common part of a branch.""" + tree.insert("apple", "tenant_1", 1) + tree.insert("apricot", "tenant_2", 2) + matched_text, matched_tenants = tree.prefix_match("application") + assert matched_text == "appl" # Longest of ("appl", "ap") + assert matched_tenants == ["tenant_1"] + + def test_prefix_match_with_tenant_filter(self, tree: PrefixTree) -> None: + """Test prefix_match with a tenant filter selecting a specific branch.""" + tree.insert("apple", "tenant_1", 1) + tree.insert("apricot", "tenant_2", 2) + matched_text, matched_tenants = tree.prefix_match("application", ["tenant_2"]) + assert matched_text == "ap" + assert matched_tenants == ["tenant_2"] + + def test_prefix_match_with_non_existent_tenant_filter( + self, tree: PrefixTree + ) -> None: + """Test prefix_match with a filter for a non-existent tenant returns no match.""" + tree.insert("apple", "tenant_1", 1) + matched_text, matched_tenants = tree.prefix_match( + "application", ["non_existent_tenant"] + ) + assert matched_text == "" + assert matched_tenants is None + + +class TestPrefixTreeRemove: + def test_remove_single_leaf_node_pruned(self, tree: PrefixTree) -> None: + """Test _remove_tenant_single_node for a leaf node; node should be pruned.""" + tree.insert("hello", "tenant_1", 1) + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1} + assert tree.tenant_to_char_count == {"tenant_1": 5} + assert tree.root.edge_label_to_child == {"h": hello_node} + + removed_chars = tree._remove_tenant_single_node("tenant_1", hello_node) + assert removed_chars == 5 + assert hello_node.tenant_to_last_access_time == {} + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert tree.root.edge_label_to_child == {} # Node pruned + + def test_remove_single_leaf_node_not_pruned(self, tree: PrefixTree) -> None: + """Test _remove_tenant_single_node for a leaf node; node should not be pruned.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("hello", "tenant_2", 2) + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 2} + assert tree.tenant_to_char_count == {"tenant_1": 5, "tenant_2": 5} + assert tree.root.edge_label_to_child == {"h": hello_node} + + removed_chars = tree._remove_tenant_single_node("tenant_1", hello_node) + assert removed_chars == 5 + assert hello_node.tenant_to_last_access_time == {"tenant_2": 2} + assert tree.tenant_to_char_count == {"tenant_1": 0, "tenant_2": 5} + assert tree.root.edge_label_to_child == {"h": hello_node} # Node not pruned + + def test_remove_single_node_with_non_existent_tenant( + self, tree: PrefixTree + ) -> None: + """Test _remove_tenant_single_node for a non-existent tenant is a no-op.""" + tree.insert("hello", "tenant_1", 1) + hello_node = tree.root.edge_label_to_child["h"] + removed_chars = tree._remove_tenant_single_node( + "non_existent_tenant", hello_node + ) + assert removed_chars == 0 + + def test_remove_single_node_with_non_matching_tenant( + self, tree: PrefixTree + ) -> None: + """Test _remove_tenant_single_node if node doesn't belong to specified tenant is a no-op.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("world", "tenant_2", 2) # Node for tenant_2 + hello_node = tree.root.edge_label_to_child["h"] # Belongs to tenant_1 + removed_chars = tree._remove_tenant_single_node( + "tenant_2", hello_node + ) # Try removing tenant_2 from tenant_1's node + assert removed_chars == 0 + + def test_remove_tenant(self, tree: PrefixTree) -> None: + """Test remove_tenant for a tree with multiple tenants only removes the specified tenant.""" + tree.insert("hello", "tenant_1", 1) + tree.insert("foobar", "tenant_1", 2) + tree.insert("helloworld", "tenant_2", 3) + removed_chars = tree.remove_tenant("tenant_1") + assert removed_chars == 11 + hello_node = tree.root.edge_label_to_child["h"] + assert hello_node.tenant_to_last_access_time == {"tenant_2": 3} + assert tree.tenant_to_char_count == {"tenant_2": 10} + assert set(tree.tenant_to_lru_tail.keys()) == {"tenant_2"} + tenant_2_lru_texts = get_lru_texts_from_tree(tree, "tenant_2") + assert tenant_2_lru_texts == ["", "world", "hello"] + + def test_remove_non_existent_tenant(self, tree: PrefixTree) -> None: + """Test remove_tenant for a non-existent tenant returns 0.""" + tree.insert("hello", "tenant_1", 1) + removed_chars = tree.remove_tenant("non_existent_tenant") + assert removed_chars == 0 + + def test_remove_tenant_prunes_nodes(self, tree: PrefixTree) -> None: + """Test remove_tenant prunes nodes that become tenant-less and childless.""" + tree.insert("helloworld", "tenant_1", 1) # Creates "helloworld" + tree.insert( + "hellothere", "tenant_2", 2 + ) # Splits into "hello" -> "world" and "hello" -> "there" + + tree.remove_tenant( + "tenant_1" + ) # "world" node should be pruned. "hello" and "there" remain for tenant_2. + + hello_node = tree.root.edge_label_to_child["h"] + assert set(hello_node.edge_label_to_child.keys()) == { + "t" + } # "w" (world) child is gone + assert hello_node.edge_label_to_child["t"].text == "there" + assert hello_node.edge_label_to_child["t"].tenant_to_last_access_time == { + "tenant_2": 2 + } + + +class TestPrefixTreeEviction: + def test_eviction_non_existent_tenant(self, tree: PrefixTree) -> None: + """Test evict_tenant_by_lru for a non-existent tenant returns 0.""" + assert tree.evict_tenant_by_lru("nonexistent_tenant", 5) == 0 + + def test_eviction_exact_min_remove_size_single_node(self, tree: PrefixTree) -> None: + """Test evicting exactly min_remove_size characters from a single oldest node.""" + tree.insert("a", "tenant_1", 1) # Oldest (1 char) + tree.insert("bb", "tenant_1", 2) + tree.insert("ccc", "tenant_1", 3) + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "ccc", "bb", "a"] + + evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) # Evict "a" + assert evicted_count == 1 + assert tree.tenant_to_char_count == {"tenant_1": 5} # 6 - 1 + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "ccc", "bb"] + + def test_eviction_exceed_min_remove_size_single_node( + self, tree: PrefixTree + ) -> None: + """Test evicting more than min_remove_size characters from a single oldest node.""" + tree.insert("aaa", "tenant_1", 1) # Oldest (2 chars) + tree.insert("bb", "tenant_1", 2) + tree.insert("c", "tenant_1", 3) + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "c", "bb", "aaa"] + + evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) # Evict "aaa" + assert evicted_count == 3 + assert tree.tenant_to_char_count == {"tenant_1": 3} # 6 - 3 + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "c", "bb"] + + def test_eviction_multiple_nodes(self, tree: PrefixTree) -> None: + """Test evicting multiple oldest nodes to meet min_remove_size.""" + tree.insert("a", "tenant_1", 1) # Oldest (1 char) + tree.insert("bb", "tenant_1", 2) # Next oldest (2 chars) + tree.insert("ccc", "tenant_1", 3) + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "ccc", "bb", "a"] + + evicted_count = tree.evict_tenant_by_lru("tenant_1", 2) # Evict "a" and "b" + assert evicted_count == 3 # 1 ("a") + 2 ("b") + assert tree.tenant_to_char_count["tenant_1"] == 3 # 6 - 3 + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "ccc"] + + def test_eviction_same_timestamps(self, tree: PrefixTree) -> None: + """Test evicting more than min_remove_size if multiple nodes share the oldest timestamp.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello", "world"] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "there", "hello"] + + # Should remove both "hello" and "world" because they have the same timestamp + evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) # Request 1 char + assert evicted_count == 10 # Removes "hello" and "world" + assert tree.tenant_to_char_count == {"tenant_1": 0, "tenant_2": 10} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "there", "hello"] + + def test_eviction_insufficient_chars_evicts_all(self, tree: PrefixTree) -> None: + """Test evicting when min_remove_size is larger than available; evicts all.""" + tree.insert("xyz", "tenant_1", 1) # 3 chars available + evicted_count = tree.evict_tenant_by_lru("tenant_1", 10) + assert evicted_count == 3 + assert tree.tenant_to_char_count == {"tenant_1": 0} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + + +class TestPrefixTreeGetSmallestTenant: + def test_get_smallest_tenant(self, tree: PrefixTree) -> None: + """Test get_smallest_tenant identifies the tenant with the fewest characters.""" + tree.insert("aaaa", "tenant_1", 1) # 4 chars + tree.insert("bb", "tenant_2", 2) # 2 chars + tree.insert("c", "tenant_3", 3) # 1 char + assert tree.get_smallest_tenant() == "tenant_3" + + def test_get_smallest_tenant_empty_tree(self, tree: PrefixTree) -> None: + """Test get_smallest_tenant on an empty tree returns None.""" + assert tree.get_smallest_tenant() is None + + def test_get_smallest_tenant_after_update(self, tree: PrefixTree) -> None: + """Test get_smallest_tenant after removing the current smallest tenant.""" + tree.insert("aaaa", "tenant_1", 1) + tree.insert("bb", "tenant_2", 2) + tree.insert("c", "tenant_3", 3) + tree.remove_tenant("tenant_3") # Remove "c" (1 char) + assert ( + tree.get_smallest_tenant() == "tenant_2" + ) # "bb" (2 chars) is now smallest + + +class TestPrefixTreeComprehensive: + """Comprehensive tests for the PrefixTree""" + + def test_tree_structure_multiple_insertions(self, tree: PrefixTree) -> None: + """Test tree structure after multiple insertions.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + tree.insert("hellothomas", "tenant_2", 3) + + # Access tree directly + root: Node = tree.root + + # Test tree structure - validate each node + # Root node + assert root.text == "" + assert root.parent is None + assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert set(root.edge_label_to_child.keys()) == {"h"} + + # Hello node + hello_node: Node = root.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.parent.text == "" + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert set(hello_node.edge_label_to_child.keys()) == {"w", "t"} + + # World node + world_node: Node = hello_node.edge_label_to_child["w"] + assert world_node.text == "world" + assert world_node.parent.text == "hello" + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert set(world_node.edge_label_to_child.keys()) == set() + + # Th node + th_node: Node = hello_node.edge_label_to_child["t"] + assert th_node.text == "th" + assert th_node.parent.text == "hello" + assert th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert set(th_node.edge_label_to_child.keys()) == {"e", "o"} + + # Ere node + ere_node: Node = th_node.edge_label_to_child["e"] + assert ere_node.text == "ere" + assert ere_node.parent.text == "th" + assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert set(ere_node.edge_label_to_child.keys()) == set() + + # Omas node + omas_node: Node = th_node.edge_label_to_child["o"] + assert omas_node.text == "omas" + assert omas_node.parent.text == "th" + assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert set(omas_node.edge_label_to_child.keys()) == set() + + def test_multiple_evictions_maintains_lru_order(self, tree: PrefixTree) -> None: + """Test multiple evictions maintain LRU order.""" + tree.insert("helloworld", "tenant_1", 1) + tree.insert("hellothere", "tenant_2", 2) + tree.insert("hellothomas", "tenant_2", 3) + assert tree.tenant_to_char_count == {"tenant_1": 10, "tenant_2": 14} + assert get_lru_texts_from_tree(tree, "tenant_1") == ["", "hello", "world"] + assert get_lru_texts_from_tree(tree, "tenant_2") == [ + "", + "omas", + "th", + "hello", + "ere", + ] + + # Eviction 1 (tenant_1): min_remove_size=1. "hello" and "world" removed. + evicted_1 = tree.evict_tenant_by_lru("tenant_1", 1) + assert evicted_1 == 10 + assert tree.tenant_to_char_count == {"tenant_1": 0, "tenant_2": 14} + assert get_lru_texts_from_tree(tree, "tenant_1") == [""] + assert get_lru_texts_from_tree(tree, "tenant_2") == [ + "", + "omas", + "th", + "hello", + "ere", + ] # T2 unchanged + + # Eviction 2 (tenant_2): min_remove_size=1. "ere" is oldest timestamp, removed. + evicted_2 = tree.evict_tenant_by_lru("tenant_2", 1) + assert evicted_2 == 3 # "ere" is 3 chars + assert tree.tenant_to_char_count == {"tenant_1": 0, "tenant_2": 11} # 14 - 3 + assert get_lru_texts_from_tree(tree, "tenant_2") == ["", "omas", "th", "hello"] + + # Eviction 3 (tenant_2): min_remove_size=1. "omas"(ts3), "th"(ts3), "hello"(ts3) removed. + evicted_3 = tree.evict_tenant_by_lru("tenant_2", 1) + assert evicted_3 == 11 # 4+2+5 chars + assert tree.tenant_to_char_count == {"tenant_1": 0, "tenant_2": 0} + assert get_lru_texts_from_tree(tree, "tenant_2") == [""] -# PrefixTreeActor tests @pytest.mark.asyncio -async def test_tree_actor(tree_actor) -> None: - """Test the PrefixTreeActor.""" - # 1. Test tree structure and LRU heap ordering - tree_actor._reset.remote() - - # Insert strings in specified order - tree_actor.insert.remote("helloworld", "tenant_1", 1) # time 1 for tenant_1 - tree_actor.insert.remote("hellothere", "tenant_2", 2) # time 2 for tenant_2 - tree_actor.insert.remote("hellothomas", "tenant_2", 3) # time 3 for tenant_2 - - # Access tree directly - tree_rep: Dict = ray.get(tree_actor._to_dict.remote()) - root: Node = tree_rep["root"] - - # Test tree structure - validate each node - # Root node - assert root.text == "" - assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "h" in root.edge_label_to_child - - # Hello node - hello_node: Node = root.edge_label_to_child["h"] - assert hello_node.text == "hello" - assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} - assert "w" in hello_node.edge_label_to_child - assert "t" in hello_node.edge_label_to_child - - # World node - world_node: Node = hello_node.edge_label_to_child["w"] - assert world_node.text == "world" - assert world_node.tenant_to_last_access_time == {"tenant_1": 1} - assert len(world_node.edge_label_to_child) == 0 - - # Th node - th_node: Node = hello_node.edge_label_to_child["t"] - assert th_node.text == "th" - assert th_node.tenant_to_last_access_time == {"tenant_2": 3} - assert "e" in th_node.edge_label_to_child - assert "o" in th_node.edge_label_to_child - - # Ere node - ere_node: Node = th_node.edge_label_to_child["e"] - assert ere_node.text == "ere" - assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} - assert len(ere_node.edge_label_to_child) == 0 - - # Omas node - omas_node: Node = th_node.edge_label_to_child["o"] - assert omas_node.text == "omas" - assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} - assert len(omas_node.edge_label_to_child) == 0 - - # Test PrefixTree instance variables - # Using tenant_to_nodes instead of tenants - assert set(tree_rep["tenant_to_nodes"].keys()) == {"tenant_1", "tenant_2"} - - # Test tenant_to_nodes (check by text) - tenant1_nodes_texts: Set[str] = { - node.text for node in tree_rep["tenant_to_nodes"]["tenant_1"] - } - assert tenant1_nodes_texts == {"", "hello", "world"} - - tenant2_nodes_texts: Set[str] = { - node.text for node in tree_rep["tenant_to_nodes"]["tenant_2"] - } - assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - - # Test tenant_to_char_count - # Before evictions - assert ( - tree_rep["tenant_to_char_count"]["tenant_1"] == 10 - ) # root(0) + hello(5) + world(5) = 10 - assert ( - tree_rep["tenant_to_char_count"]["tenant_2"] == 14 - ) # root(0) + hello(5) + th(2) + ere(3) + omas(4) = 14 - - # After evicting tenant_1 with min_remove_size=1 - # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp - evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_1", 1)) - assert evicted_count == 10 # All 10 chars removed, not just 1 - tree_rep = ray.get(tree_actor._to_dict.remote()) - assert tree_rep["tenant_to_char_count"]["tenant_1"] == 0 - - # After evicting tenant_2 with min_remove_size=1 - # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) - evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) - assert evicted_count == 3 # All 3 chars from "ere" removed - - tree_rep = ray.get(tree_actor._to_dict.remote()) - assert tree_rep["tenant_to_char_count"]["tenant_2"] == 11 # 14 - 3 = 11 - - # After evicting tenant_2 again with min_remove_size=1 - # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 - evicted_count = ray.get(tree_actor.evict_tenant_by_lru.remote("tenant_2", 1)) - assert evicted_count == 11 # All 11 remaining chars removed - - tree_rep = ray.get(tree_actor._to_dict.remote()) - assert tree_rep["tenant_to_char_count"]["tenant_2"] == 0 - - -# PrefixTree tests -def test__add_tenant(tree: PrefixTree) -> None: - """Test adding tenants to the tree via the private _add_tenant method.""" - # 1. Test basic tenant addition - tree._reset() - tree._add_tenant("tenant_1") - assert "tenant_1" in tree.tenant_to_nodes - assert tree.tenant_to_char_count["tenant_1"] == 0 - assert tree.tenant_to_nodes["tenant_1"] == set() - - # 2. Test adding duplicate tenant logs warning but doesn't raise error - tree._reset() - tree._add_tenant("tenant_1") - # This should be a no-op - tree._add_tenant("tenant_1") - # Verify the tenant still exists - assert "tenant_1" in tree.tenant_to_nodes - - -def test_insert(tree: PrefixTree) -> None: - """Test the insert functionality of PrefixTree.""" - # 1. Test basic insertion - tree._reset() - # No need to call add_tenant first - insert will do it automatically - tree.insert("hello", "tenant_1", 1) - matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" and matched_tenants == ["tenant_1"] - assert tree.tenant_to_char_count["tenant_1"] == 5 - assert len(tree.tenant_to_nodes["tenant_1"]) == 2 - - # 2. Test duplicate insertion doesn't double count - tree._reset() - tree.insert("foo", "tenant_1", 1) - tree.insert("foo", "tenant_1", 1) # duplicate - tree.insert("bar", "tenant_2", 2) - assert ( - tree.tenant_to_char_count["tenant_1"] == 3 - and tree.tenant_to_char_count["tenant_2"] == 3 - ) - - # 3. Test node splitting on partial match - tree._reset() - tree.insert("helloworld", "tenant_1", 1) - tree.insert("hellothere", "tenant_2", 2) - - root: Node = tree.root - h_node: Optional[Node] = root.edge_label_to_child.get("h") - assert h_node is not None and h_node.text == "hello" - assert h_node.edge_label_to_child.get("w").text == "world" - assert h_node.edge_label_to_child.get("t").text == "there" - - # 4. Test that inserting a longer prompt with shared prefix doesn't create empty text nodes - tree._reset() - tree.insert("hello", "tenant_1", 1) - tree.insert("helloworld", "tenant_2", 2) - - root = tree.root - - # Check that only the root has empty text by directly traversing the tree - # Starting from root, collect all nodes with empty text - empty_text_nodes: List[Node] = [] - nodes_to_check: List[Node] = [root] - - while nodes_to_check: - node: Node = nodes_to_check.pop() - if node.text == "": - empty_text_nodes.append(node) - # Add all children to check - nodes_to_check.extend(node.edge_label_to_child.values()) - - # There should be exactly one empty text node (the root) - assert len(empty_text_nodes) == 1 and root in empty_text_nodes - - # Verify tree structure - h_node = root.edge_label_to_child.get("h") - assert h_node is not None and h_node.text == "hello" - assert ( - "tenant_1" in h_node.tenant_to_last_access_time - and "tenant_2" in h_node.tenant_to_last_access_time - ) - - # Verify "world" node belongs only to tenant 2 - world_node: Optional[Node] = h_node.edge_label_to_child.get("w") - assert world_node is not None and world_node.text == "world" - assert ( - "tenant_2" in world_node.tenant_to_last_access_time - and "tenant_1" not in world_node.tenant_to_last_access_time - ) - - # Verify the only child of h_node is "w" - assert len(h_node.edge_label_to_child) == 1 - - -def test_prefix_match(tree: PrefixTree) -> None: - """Test the prefix_match functionality of PrefixTree.""" - # 1. Test no match - tree._reset() - matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "" and matched_tenants is None - - # 2. Test match with non-existing prefix returns empty string and all tenants - tree._reset() - tree.insert("hello", "tenant_1", 1) - tree.insert("hellothere", "tenant_2", 2) - matched_text, matched_tenants = tree.prefix_match("foobar") - assert matched_text == "" and matched_tenants == ["tenant_1", "tenant_2"] - - # 3. Test exact match - tree._reset() - tree.insert("hello", "tenant_1", 1) - matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" and matched_tenants == ["tenant_1"] - - # 4. Test partial match - tree._reset() - tree.insert("apple", "tenant_1", 1) - tree.insert("apricot", "tenant_2", 2) - matched_text, matched_tenants = tree.prefix_match("application") - assert matched_text == "appl" and matched_tenants == ["tenant_1"] - - # 5. Test match by tenant - tree._reset() - tree.insert("apple", "tenant_1", 1) - tree.insert("apricot", "tenant_2", 2) - matched_text, matched_tenants = tree.prefix_match("application", ["tenant_2"]) - assert matched_text == "ap" and matched_tenants == ["tenant_2"] - - # 6. Test match by non-existent tenant - tree._reset() - tree.insert("apple", "tenant_1", 1) - tree.insert("apricot", "tenant_2", 2) - matched_text, matched_tenants = tree.prefix_match("application", ["tenant_3"]) - assert matched_text == "" and matched_tenants is None - - # 7. Test shared prefix matching with branches - tree._reset() - tree.insert("helloworld", "tenant_1", 1) - tree.insert("hellothere", "tenant_2", 2) - - matched_text, matched_tenants = tree.prefix_match("helloworld") - assert matched_text == "helloworld" and matched_tenants == ["tenant_1"] - - matched_text, matched_tenants = tree.prefix_match("hellothereworld") - assert matched_text == "hellothere" and matched_tenants == ["tenant_2"] - - -def test__remove_tenant_single_node(tree: PrefixTree) -> None: - """Test removing a single node for a tenant.""" - # 1. Test removing a single node - tree._reset() - tree.insert("hello", "tenant_1", 1) - h_node: Node = tree.root.edge_label_to_child["h"] - - removed: int = tree._remove_tenant_single_node("tenant_1", h_node) - assert removed == 5 - assert tree.tenant_to_char_count["tenant_1"] == 0 - assert ( - len(tree.tenant_to_nodes["tenant_1"]) == 1 - and tree.root in tree.tenant_to_nodes["tenant_1"] - ) - - # 2. Test removing node for non-existent tenant is idempotent - tree._reset() - tree.insert("hello", "tenant_1", 1) - root: Node = tree.root - h_node: Optional[Node] = root.edge_label_to_child.get("h") - - # Should not raise error, just return 0 - removed = tree._remove_tenant_single_node("nonexistent_tenant", h_node) - assert removed == 0 - - # 3. Test removing node that doesn't belong to tenant is idempotent - tree._reset() - tree.insert("hello", "tenant_1", 1) - tree.insert("world", "tenant_2", 2) - - root = tree.root - h_node = root.edge_label_to_child.get("h") - - # Should not raise error, just return 0 - removed = tree._remove_tenant_single_node("tenant_2", h_node) - assert removed == 0 - - -def test_remove_tenant(tree: PrefixTree) -> None: - """Test removing a tenant from the tree.""" - # 1. Test basic tenant removal - tree._reset() - tree.insert("hello", "tenant_1", 1) - removed: int = tree.remove_tenant("tenant_1") - assert removed == 5 - assert ( - "tenant_1" not in tree.tenant_to_nodes - and "tenant_1" not in tree.tenant_to_char_count - ) - - # 2. Test removing tenant with multiple nodes - tree._reset() - tree.insert("cat", "tenant_1", 1) - tree.insert("dog", "tenant_1", 2) - removed = tree.remove_tenant("tenant_1") - assert removed == len("cat") + len("dog") - - # 3. Test removing non-existent tenant is idempotent (logs warning, returns 0) - tree._reset() - # Should not raise error, just return 0 - removed = tree.remove_tenant("nonexistent_tenant") - assert removed == 0 - - # 4. Test tree structure after removing tenant - tree._reset() - tree.insert("hello", "tenant_1", 1) - tree.insert("hello", "tenant_2", 2) - - # Remove tenant_1, verify tenant_2 still works - tree.remove_tenant("tenant_1") - assert "tenant_1" not in tree.tenant_to_nodes and "tenant_2" in tree.tenant_to_nodes - - matched_text, matched_tenants = tree.prefix_match("hello") - assert matched_text == "hello" and matched_tenants == ["tenant_2"] - - # 5. Test removing the last tenant from a node removes the node - tree._reset() - tree.insert("helloworld", "tenant_1", 1) - tree.insert("hellothere", "tenant_2", 2) - - # Remove tenant_1 - tree.remove_tenant("tenant_1") - - root: Node = tree.root - # 'h' node should only have one child now ('t' from hellothere) - assert "h" in root.edge_label_to_child - assert "t" in root.edge_label_to_child["h"].edge_label_to_child - assert len(root.edge_label_to_child["h"].edge_label_to_child) == 1 - - -def test_evict_tenant_by_lru(tree: PrefixTree) -> None: - """Test the evict_tenant_by_lru functionality of PrefixTree.""" - - # 1. Remove exactly min_remove_size characters - tree._reset() - tree.insert("a", "tenant_1", 1) - tree.insert("bb", "tenant_1", 2) - tree.insert("ccc", "tenant_1", 3) - - # Before eviction - char_count_before: int = tree.tenant_to_char_count["tenant_1"] - assert ( - len(tree.tenant_to_nodes["tenant_1"]) == 4 - and tree.tenant_to_char_count["tenant_1"] == 6 - ) - - # During eviction - min_remove_size: int = 1 - evicted_count: int = tree.evict_tenant_by_lru("tenant_1", min_remove_size) - - # After eviction - char_count_after: int = tree.tenant_to_char_count["tenant_1"] - assert evicted_count == min_remove_size - assert char_count_before - char_count_after == evicted_count - assert ( - len(tree.tenant_to_nodes["tenant_1"]) == 3 - and tree.tenant_to_char_count["tenant_1"] == 5 - ) - - # 2. Remove more than min_remove_size characters - tree._reset() - tree.insert("a", "tenant_1", 1) - tree.insert("bb", "tenant_1", 2) - tree.insert("ccc", "tenant_1", 3) - - # Before eviction - char_count_before = tree.tenant_to_char_count["tenant_1"] - assert ( - len(tree.tenant_to_nodes["tenant_1"]) == 4 - and tree.tenant_to_char_count["tenant_1"] == 6 - ) - - # During eviction - min_remove_size = 2 - evicted_count = tree.evict_tenant_by_lru("tenant_1", min_remove_size) - - # After eviction - char_count_after = tree.tenant_to_char_count["tenant_1"] - assert evicted_count != min_remove_size and evicted_count == 3 - assert char_count_before - char_count_after == evicted_count - assert ( - len(tree.tenant_to_nodes["tenant_1"]) == 2 - and tree.tenant_to_char_count["tenant_1"] == 3 - ) - - # 3. Test eviction of non-existent tenant is idempotent - tree._reset() - # Should not raise error, just return 0 - evicted_count = tree.evict_tenant_by_lru("nonexistent_tenant", 5) - assert evicted_count == 0 - - # 4. Test eviction of tenant with insufficient characters is idempotent - tree._reset() - tree.insert("xyz", "tenant_1", 1) - # Should not raise error, should evict all available characters - evicted_count = tree.evict_tenant_by_lru("tenant_1", 4) - assert evicted_count == 3 # "xyz" has 3 characters - - # 5. Test eviction of all tenant data - tree._reset() - tree.insert("xyz", "tenant_1", 1) - - total_size: int = tree.tenant_to_char_count["tenant_1"] - evicted_count = tree.evict_tenant_by_lru("tenant_1", total_size) - assert evicted_count == total_size - # "tenant_1" should still be in tenant_to_nodes - assert "tenant_1" in tree.tenant_to_nodes - - # 6. Test tree structure and LRU eviction - tree._reset() - - # Insert strings in specified order - tree.insert("helloworld", "tenant_1", 1) # time 1 for tenant_1 - tree.insert("hellothere", "tenant_2", 2) # time 2 for tenant_2 - tree.insert("hellothomas", "tenant_2", 3) # time 3 for tenant_2 - - # Access tree directly - root: Node = tree.root - - # Test tree structure - validate each node - # Root node - assert root.text == "" and root.tenant_to_last_access_time == { - "tenant_1": 1, - "tenant_2": 3, - } - assert "h" in root.edge_label_to_child - - # Hello node - hello_node: Node = root.edge_label_to_child["h"] - assert hello_node.text == "hello" and hello_node.tenant_to_last_access_time == { - "tenant_1": 1, - "tenant_2": 3, - } - assert ( - "w" in hello_node.edge_label_to_child and "t" in hello_node.edge_label_to_child - ) - - # World node - world_node: Node = hello_node.edge_label_to_child["w"] - assert world_node.text == "world" and world_node.tenant_to_last_access_time == { - "tenant_1": 1 - } - assert len(world_node.edge_label_to_child) == 0 - - # Th node - th_node: Node = hello_node.edge_label_to_child["t"] - assert th_node.text == "th" and th_node.tenant_to_last_access_time == { - "tenant_2": 3 - } - assert "e" in th_node.edge_label_to_child and "o" in th_node.edge_label_to_child - - # Ere node - ere_node: Node = th_node.edge_label_to_child["e"] - assert ere_node.text == "ere" and ere_node.tenant_to_last_access_time == { - "tenant_2": 2 - } - assert len(ere_node.edge_label_to_child) == 0 - - # Omas node - omas_node: Node = th_node.edge_label_to_child["o"] - assert omas_node.text == "omas" and omas_node.tenant_to_last_access_time == { - "tenant_2": 3 - } - assert len(omas_node.edge_label_to_child) == 0 - - # Test PrefixTree instance variables - assert set(tree.tenant_to_nodes.keys()) == {"tenant_1", "tenant_2"} - - # Test tenant_to_nodes (check by text) - tenant1_nodes_texts: Set[str] = { - node.text for node in tree.tenant_to_nodes["tenant_1"] - } - assert tenant1_nodes_texts == {"", "hello", "world"} - - tenant2_nodes_texts: Set[str] = { - node.text for node in tree.tenant_to_nodes["tenant_2"] - } - assert tenant2_nodes_texts == {"", "hello", "th", "ere", "omas"} - - # Test tenant_to_char_count - # Before evictions - assert ( - tree.tenant_to_char_count["tenant_1"] == 10 - and tree.tenant_to_char_count["tenant_2"] == 14 - ) - - # After evicting tenant_1 with min_remove_size=1 - # Should remove both "hello" and "world" nodes (10 chars) since they have the same timestamp - evicted_count = tree.evict_tenant_by_lru("tenant_1", 1) - assert evicted_count == 10 and tree.tenant_to_char_count["tenant_1"] == 0 - - # After evicting tenant_2 with min_remove_size=1 - # Should remove "ere" node (3 chars) since it has the oldest timestamp (2) - evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) - assert ( - evicted_count == 3 and tree.tenant_to_char_count["tenant_2"] == 11 - ) # 14 - 3 = 11 - - # After evicting tenant_2 again with min_remove_size=1 - # Should remove "hello", "th", and "omas" nodes (11 chars) since they all have timestamp 3 - evicted_count = tree.evict_tenant_by_lru("tenant_2", 1) - assert evicted_count == 11 and tree.tenant_to_char_count["tenant_2"] == 0 - - -def test_get_smallest_tenant(tree: PrefixTree) -> None: - """Test the get_smallest_tenant functionality of PrefixTree.""" - # 1. Test with empty tree - tree._reset() - smallest: Optional[str] = tree.get_smallest_tenant() - assert smallest is None - - # 2. Test with multiple tenants of different sizes - tree._reset() - tree.insert("aaaa", "tenant_1", 1) - tree.insert("bb", "tenant_2", 2) - tree.insert("c", "tenant_3", 3) - - smallest = tree.get_smallest_tenant() - assert smallest == "tenant_3" - - # 3. Test after removing the smallest tenant - tree._reset() - tree.insert("aaaa", "tenant_1", 1) - tree.insert("bb", "tenant_2", 2) - tree.insert("c", "tenant_3", 3) - tree.remove_tenant("tenant_3") - smallest = tree.get_smallest_tenant() - assert smallest == "tenant_2" +class TestPrefixTreeActorComprehensive: + """Comprehensive tests for the PrefixTreeActor""" + + async def test_tree_structure_multiple_insertions_actor( + self, tree_actor: PrefixTreeActor + ) -> None: + # Insert strings in specified order + tree_actor.insert.remote("helloworld", "tenant_1", 1) + tree_actor.insert.remote("hellothere", "tenant_2", 2) + tree_actor.insert.remote("hellothomas", "tenant_2", 3) + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_1") == [ + "", + "hello", + "world", + ] + + # Access tree directly + root: Node = ray.get(tree_actor.getattr.remote("root")) + + # Test tree structure - validate each node + # Root node + assert root.text == "" + assert root.parent is None + assert root.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert set(root.edge_label_to_child.keys()) == {"h"} + + # Hello node + hello_node: Node = root.edge_label_to_child["h"] + assert hello_node.text == "hello" + assert hello_node.parent.text == "" + assert hello_node.tenant_to_last_access_time == {"tenant_1": 1, "tenant_2": 3} + assert set(hello_node.edge_label_to_child.keys()) == {"w", "t"} + + # World node + world_node: Node = hello_node.edge_label_to_child["w"] + assert world_node.text == "world" + assert world_node.parent.text == "hello" + assert world_node.tenant_to_last_access_time == {"tenant_1": 1} + assert set(world_node.edge_label_to_child.keys()) == set() + + # Th node + th_node: Node = hello_node.edge_label_to_child["t"] + assert th_node.text == "th" + assert th_node.parent.text == "hello" + assert th_node.tenant_to_last_access_time == {"tenant_2": 3} + assert set(th_node.edge_label_to_child.keys()) == {"e", "o"} + + # Ere node + ere_node: Node = th_node.edge_label_to_child["e"] + assert ere_node.text == "ere" + assert ere_node.parent.text == "th" + assert ere_node.tenant_to_last_access_time == {"tenant_2": 2} + assert set(ere_node.edge_label_to_child.keys()) == set() + + # Omas node + omas_node: Node = th_node.edge_label_to_child["o"] + assert omas_node.text == "omas" + assert omas_node.parent.text == "th" + assert omas_node.tenant_to_last_access_time == {"tenant_2": 3} + assert set(omas_node.edge_label_to_child.keys()) == set() + + async def test_multiple_evictions_maintains_lru_order_actor( + self, tree_actor: PrefixTreeActor + ) -> None: + """Test multiple evictions maintain LRU order.""" + tree_actor.insert.remote("helloworld", "tenant_1", 1) + tree_actor.insert.remote("hellothere", "tenant_2", 2) + tree_actor.insert.remote("hellothomas", "tenant_2", 3) + assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { + "tenant_1": 10, + "tenant_2": 14, + } + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_1") == [ + "", + "hello", + "world", + ] + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_2") == [ + "", + "omas", + "th", + "hello", + "ere", + ] + + # Eviction 1 (tenant_1): min_remove_size=1. "hello" and "world" removed. + evicted_1 = await tree_actor.evict_tenant_by_lru.remote("tenant_1", 1) + assert evicted_1 == 10 + assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { + "tenant_1": 0, + "tenant_2": 14, + } + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_1") == [""] + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_2") == [ + "", + "omas", + "th", + "hello", + "ere", + ] # T2 unchanged + + # Eviction 2 (tenant_2): min_remove_size=1. "ere" is oldest timestamp, removed. + evicted_2 = await tree_actor.evict_tenant_by_lru.remote("tenant_2", 1) + assert evicted_2 == 3 # "ere" is 3 chars + assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { + "tenant_1": 0, + "tenant_2": 11, + } # 14 - 3 + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_2") == [ + "", + "omas", + "th", + "hello", + ] + + # Eviction 3 (tenant_2): min_remove_size=1. "omas"(ts3), "th"(ts3), "hello"(ts3) removed. + evicted_3 = await tree_actor.evict_tenant_by_lru.remote("tenant_2", 1) + assert evicted_3 == 11 # 4+2+5 chars + assert ray.get(tree_actor.getattr.remote("tenant_to_char_count")) == { + "tenant_1": 0, + "tenant_2": 0, + } + assert await get_lru_texts_from_tree_actor(tree_actor, "tenant_2") == [""] if __name__ == "__main__": import sys - sys.exit(pytest.main(["-v", __file__])) + exit_code = pytest.main(["-v", __file__]) + sys.exit(exit_code)