
Commit 24690e4

Merge pull request #141 from sbak5/sbak/attr_module_pr
Base Attribution Module for a modular attribution pipeline
2 parents e0fa23e + 03fa9f0 · commit 24690e4

2 files changed: +461 −0 lines
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
import asyncio
from concurrent.futures import ThreadPoolExecutor
from enum import Enum, auto
from functools import partial
from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, TypeVar, Union

T = TypeVar('T')  # Input type
R = TypeVar('R')  # Attribution result type


class AttributionState(Enum):
    STOP = auto()
    CONTINUE = auto()


class NVRxAttribution(Generic[T, R]):
    """A class that implements a three-step attribution process.

    This class is designed to be used in a pipeline of attribution modules.
    The output of one attribution module can be used as the input to the next attribution module.

    This class handles:
    1. Input preprocessing - can handle single objects or lists of objects
    2. Attribution computation
    3. Output handling
    """

    # Shared event loop for all instances
    _shared_loop = None
    _loop_lock = asyncio.Lock()

    @classmethod
    def get_shared_loop(cls):
        """Get or create the shared event loop."""
        if cls._shared_loop is None or cls._shared_loop.is_closed():
            cls._shared_loop = asyncio.new_event_loop()
            asyncio.set_event_loop(cls._shared_loop)
        return cls._shared_loop

    def __init__(
        self,
        preprocess_input: Callable[[Union[T, List[T]]], Any],
        attribution: Callable[[Any], R],
        output_handler: Callable[[R], None],
        attribution_kwargs: Optional[Dict[str, Any]] = None,
        thread_pool: Optional[ThreadPoolExecutor] = None,
    ):
        """Initialize the attribution module.

        Args:
            preprocess_input: Function to preprocess the input data. Can handle single objects or lists.
            attribution: Function to perform the attribution computation.
            output_handler: Function to handle the attribution results.
            attribution_kwargs: Optional keyword arguments to pass to the attribution function.
            thread_pool: Optional thread pool for running sync functions.
        """
        self._preprocess_input = preprocess_input
        self._attribution = attribution
        self._output_handler = output_handler
        self.attribution_kwargs = attribution_kwargs or {}
        self._thread_pool = thread_pool or ThreadPoolExecutor(max_workers=2)

        # Get the shared loop and set the thread pool as its default executor
        self._loop = self.get_shared_loop()
        self._loop.set_default_executor(self._thread_pool)

    async def _run_sync_in_thread(self, func: Callable, *args, **kwargs) -> Any:
        """Run a synchronous function in a thread pool.

        Args:
            func: The synchronous function to run
            *args: Positional arguments for the function
            **kwargs: Keyword arguments for the function

        Returns:
            The result of the function
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._thread_pool, partial(func, *args, **kwargs))

    async def _preprocess_input_inner(
        self, input_data: Union[T, List[T], Awaitable[Union[T, List[T]]]]
    ) -> tuple[Any, AttributionState]:
        """Preprocess the input data.

        Args:
            input_data: The raw input data to be preprocessed. Can be:
                - A single object of type T
                - A list of objects of type T
                - An awaitable that resolves to either of the above

        Returns:
            Preprocessed data ready for attribution, and a flag indicating whether the attribution should continue.
            If the flag is AttributionState.STOP, the attribution stops and the preprocessed data is returned as-is.
            If the flag is AttributionState.CONTINUE, the attribution continues.
        """
        # Handle awaitable inputs (e.g., outputs from other attribution modules)
        if isinstance(input_data, Awaitable):
            # A single awaitable: resolve it before preprocessing
            awaited_input_data = await input_data
        elif isinstance(input_data, list):
            # A list: resolve any awaitable items, propagating an upstream STOP
            awaited_input_data = []
            for item in input_data:
                if isinstance(item, Awaitable):
                    awaited_item = await item
                    if awaited_item[1] == AttributionState.STOP:
                        return awaited_item[0], awaited_item[1]
                    awaited_input_data.append(awaited_item[0])
                else:
                    awaited_input_data.append(item)
        else:
            awaited_input_data = input_data

        # Check if preprocess_input is async
        if asyncio.iscoroutinefunction(self._preprocess_input):
            return await self._preprocess_input(awaited_input_data), AttributionState.CONTINUE
        else:
            return (
                await self._run_sync_in_thread(self._preprocess_input, awaited_input_data),
                AttributionState.CONTINUE,
            )

    async def do_attribution(self, preprocessed_data: Any) -> R:
        """Perform the attribution computation.

        Args:
            preprocessed_data: The preprocessed input data

        Returns:
            The attribution results
        """
        # Check if attribution is async
        if asyncio.iscoroutinefunction(self._attribution):
            return await self._attribution(preprocessed_data, **self.attribution_kwargs)
        else:
            return await self._run_sync_in_thread(
                self._attribution, preprocessed_data, **self.attribution_kwargs
            )

    async def output_handler(self, attribution_result: R) -> R:
        """Handle the attribution results.

        Args:
            attribution_result: The results from the attribution computation

        Returns:
            The handled output
        """
        # Check if output_handler is async
        if asyncio.iscoroutinefunction(self._output_handler):
            return await self._output_handler(attribution_result)
        else:
            return await self._run_sync_in_thread(self._output_handler, attribution_result)

    async def run(self, input_data: Union[T, List[T], Awaitable[Union[T, List[T]]]]) -> R:
        """Run the complete attribution pipeline.

        Args:
            input_data: The raw input data to process. Can be:
                - A single object of type T
                - A list of objects of type T
                - An awaitable that resolves to either of the above

        Returns:
            The attribution results of type R
        """
        loop = asyncio.get_running_loop()

        async def _run_pipeline():
            preprocessed_data, flag_to_proceed = await self._preprocess_input_inner(input_data)
            if flag_to_proceed == AttributionState.CONTINUE:
                attribution_result = await self.do_attribution(preprocessed_data)
                final_output = await self.output_handler(attribution_result)
                return final_output
            else:
                return preprocessed_data

        return await loop.create_task(_run_pipeline())

    def run_sync(self, input_data: Union[T, List[T], Awaitable[Union[T, List[T]]]]) -> R:
        """Run the attribution pipeline synchronously.

        Args:
            input_data: The raw input data to process. Can be:
                - A single object of type T
                - A list of objects of type T
                - An awaitable that resolves to either of the above

        Returns:
            The attribution results of type R
        """
        loop = self._loop

        try:
            return loop.run_until_complete(self.run(input_data))
        finally:
            # Don't close the shared loop, just clean up if needed
            pass

    def __del__(self):
        """Cleanup thread pool on deletion."""
        if hasattr(self, '_thread_pool'):
            self._thread_pool.shutdown(wait=False)
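
For reference, a minimal usage sketch of the new module with synchronous callables and run_sync. The preprocess/attribution/output functions below are hypothetical placeholders, not part of this PR, and the import of NVRxAttribution is omitted because the new file's path is not shown on this page.

# NVRxAttribution is the class added in this PR; import it from wherever the new module lives.

def split_log_lines(raw_log: str):
    # Hypothetical preprocessing: split a raw log blob into lines
    return raw_log.splitlines()

def count_keyword_lines(lines, keyword: str = "ERROR"):
    # Hypothetical attribution: count lines containing the keyword
    return sum(1 for line in lines if keyword in line)

def report(count: int):
    # Hypothetical output handling: print and pass the result through
    print(f"lines with errors: {count}")
    return count

module = NVRxAttribution(
    preprocess_input=split_log_lines,
    attribution=count_keyword_lines,
    output_handler=report,
    attribution_kwargs={"keyword": "ERROR"},
)

# Each callable is synchronous, so it is dispatched to the module's thread pool.
result = module.run_sync("ok\nERROR: boom\nok")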

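Because run() returns an awaitable and the preprocessing step resolves awaitable inputs, modules can be chained by feeding one module's run() coroutine to the next, as the class docstring describes. A sketch of that composition, reusing the hypothetical callables from the example above and driven by asyncio.run:

import asyncio

async def main():
    stage_one = NVRxAttribution(
        preprocess_input=split_log_lines,
        attribution=count_keyword_lines,
        output_handler=lambda count: count,        # pass the count through unchanged
    )
    stage_two = NVRxAttribution(
        preprocess_input=lambda count: count,      # identity preprocessing
        attribution=lambda count: {"error_lines": count},
        output_handler=print,                      # print the final dictionary
    )
    # stage_one.run(...) is an awaitable, which stage_two resolves before preprocessing
    return await stage_two.run(stage_one.run("ok\nERROR: boom"))

asyncio.run(main())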