Skip to content

Commit bf88d78

Browse files
committed
Make code async
1 parent 8d82fd0 commit bf88d78

File tree

11 files changed

+156
-116
lines changed

11 files changed

+156
-116
lines changed

src/galileo/decorator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -930,6 +930,24 @@ def start_session(
930930
name=name, previous_session_id=previous_session_id, external_id=external_id
931931
)
932932

933+
async def async_start_session(
934+
self, name: Optional[str] = None, previous_session_id: Optional[str] = None, external_id: Optional[str] = None
935+
) -> str:
936+
"""
937+
Asynchronously start a session in the active context logger instance.
938+
939+
Args:
940+
name: The name of the session. If not provided, a session name will be generated automatically.
941+
previous_session_id: The id of the previous session. Defaults to None.
942+
external_id: The external id of the session. Defaults to None.
943+
944+
Returns:
945+
str: The id of the newly created session.
946+
"""
947+
return await self.get_logger_instance().async_start_session(
948+
name=name, previous_session_id=previous_session_id, external_id=external_id
949+
)
950+
933951
def clear_session(self) -> None:
934952
"""Clear the session in the active context logger instance."""
935953
self.get_logger_instance().clear_session()

src/galileo/logger/logger.py

Lines changed: 41 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,22 +1016,8 @@ def conclude(
10161016

10171017
return current_parent
10181018

1019-
@nop_sync
1020-
def flush(self) -> list[Trace]:
1021-
"""
1022-
Upload all traces to Galileo.
1023-
1024-
Returns:
1025-
-------
1026-
List[Trace]: The list of uploaded traces.
1027-
"""
1028-
if self.mode == "batch":
1029-
return self._flush_batch()
1030-
else:
1031-
self._logger.warning("Flushing in streaming mode is not supported.")
1032-
return list()
1033-
1034-
def _flush_batch(self):
1019+
async def _flush_batch(self, is_async: bool = False) -> list[Trace]:
1020+
# import pdb; pdb.set_trace()
10351021
if not self.traces:
10361022
self._logger.info("No traces to flush.")
10371023
return list()
@@ -1044,7 +1030,7 @@ def _flush_batch(self):
10441030

10451031
if self.local_metrics:
10461032
self._logger.info("Computing local metrics...")
1047-
# TODO: parallelize, possibly with ThreadPoolExecutor
1033+
# TODO: parallelize, possibly with ThreadPoolExecutor/asyncio
10481034
for trace in self.traces:
10491035
populate_local_metrics(trace, self.local_metrics)
10501036

@@ -1053,7 +1039,16 @@ def _flush_batch(self):
10531039
traces_ingest_request = TracesIngestRequest(
10541040
traces=self.traces, experiment_id=self.experiment_id, session_id=self.session_id
10551041
)
1056-
self._client.ingest_traces_sync(traces_ingest_request)
1042+
1043+
if is_async:
1044+
await self._client.ingest_traces(traces_ingest_request)
1045+
else:
1046+
# Use async_run() instead of asyncio.run() to work in all environments
1047+
# (Jupyter notebooks, pytest-asyncio, FastAPI, etc.)
1048+
from galileo_core.helpers.execution import async_run
1049+
1050+
async_run(self._client.ingest_traces(traces_ingest_request))
1051+
10571052
logged_traces = self.traces
10581053

10591054
self._logger.info("Successfully flushed %d traces.", len(logged_traces))
@@ -1072,39 +1067,39 @@ async def async_flush(self) -> list[Trace]:
10721067
List[Trace]: The list of uploaded traces.
10731068
"""
10741069
if self.mode == "batch":
1075-
return await self._async_flush_batch()
1070+
return await self._flush_batch(is_async=True)
10761071
else:
10771072
self._logger.warning("Flushing in streaming mode is not supported.")
10781073
return list()
10791074

1080-
async def _async_flush_batch(self) -> list[Trace]:
1081-
if not self.traces:
1082-
self._logger.info("No traces to flush.")
1083-
return list()
1084-
1085-
current_parent = self.current_parent()
1086-
if current_parent is not None:
1087-
self._logger.info("Concluding the active trace...")
1088-
last_output = get_last_output(current_parent)
1089-
self.conclude(output=last_output, conclude_all=True)
1090-
1091-
if self.local_metrics:
1092-
self._logger.info("Computing metrics for local scorers...")
1093-
# TODO: parallelize, possibly with asyncio to_thread/gather
1094-
for trace in self.traces:
1095-
populate_local_metrics(trace, self.local_metrics)
1096-
1097-
self._logger.info("Flushing %d traces...", len(self.traces))
1098-
1099-
traces_ingest_request = TracesIngestRequest(traces=self.traces, session_id=self.session_id)
1100-
await self._client.ingest_traces(traces_ingest_request)
1101-
logged_traces = self.traces
1102-
1103-
self._logger.info("Successfully flushed %d traces.", len(logged_traces))
1075+
@nop_sync
1076+
def flush(self) -> list[Trace]:
1077+
"""
1078+
Upload all traces to Galileo.
11041079
1105-
self.traces = list()
1106-
self._parent_stack = deque()
1107-
return logged_traces
1080+
Returns:
1081+
-------
1082+
List[Trace]: The list of uploaded traces.
1083+
"""
1084+
if self.mode == "batch":
1085+
# This is bad because asyncio.run() fails in environments with existing event loops
1086+
# (e.g. Jupyter notebooks, FastAPI), raising "asyncio.run() cannot be called from a running event loop".
1087+
# Even though flush() is sync, it can be called from async contexts like:
1088+
# - Jupyter notebooks (which have their own event loop)
1089+
# - pytest-asyncio tests (where @mark.asyncio creates an event loop)
1090+
# - FastAPI/Django async views (where the web framework has an event loop)
1091+
# - Any async function that calls sync code
1092+
# The EventLoopThreadPool approach works in ALL environments by using dedicated threads
1093+
# return asyncio.run(self._flush_batch(is_async=False))
1094+
1095+
# This is good because async_run() uses EventLoopThreadPool which works in all environments
1096+
# by running async code in dedicated threads with their own event loops
1097+
from galileo_core.helpers.execution import async_run
1098+
1099+
return async_run(self._flush_batch(is_async=False))
1100+
else:
1101+
self._logger.warning("Flushing in streaming mode is not supported.")
1102+
return list()
11081103

11091104
@nop_sync
11101105
def terminate(self) -> None:

src/galileo/utils/core_api_client.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,18 +85,6 @@ async def ingest_traces(self, traces_ingest_request: TracesIngestRequest) -> dic
8585
RequestMethod.POST, endpoint=Routes.traces.format(project_id=self.project_id), json=json
8686
)
8787

88-
def ingest_traces_sync(self, traces_ingest_request: TracesIngestRequest) -> dict[str, str]:
89-
if self.experiment_id:
90-
traces_ingest_request.experiment_id = UUID(self.experiment_id)
91-
elif self.log_stream_id:
92-
traces_ingest_request.log_stream_id = UUID(self.log_stream_id)
93-
94-
json = traces_ingest_request.model_dump(mode="json")
95-
96-
return self._make_request(
97-
RequestMethod.POST, endpoint=Routes.traces.format(project_id=self.project_id), json=json
98-
)
99-
10088
async def ingest_spans(self, spans_ingest_request: SpansIngestRequest) -> dict[str, str]:
10189
if self.experiment_id:
10290
spans_ingest_request.experiment_id = UUID(self.experiment_id)

tests/test_decorator.py

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def llm_call(query: str):
9090

9191
galileo_context.flush()
9292

93-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
93+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
9494

9595
assert len(payload.traces) == 1
9696
assert len(payload.traces[0].spans) == 1
@@ -130,7 +130,7 @@ def llm_call(query: str):
130130

131131
galileo_context.flush(project="project-X", log_stream="log-stream-X")
132132

133-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
133+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
134134

135135
assert len(payload.traces) == 1
136136
assert len(payload.traces[0].spans) == 1
@@ -139,7 +139,7 @@ def llm_call(query: str):
139139

140140
galileo_context.flush(project="project-Y", log_stream="log-stream-Y")
141141

142-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
142+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
143143

144144
assert len(payload.traces) == 1
145145
assert len(payload.traces[0].spans) == 1
@@ -211,7 +211,7 @@ def llm_call(query: str):
211211
llm_call(query="input")
212212
galileo_context.flush()
213213

214-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
214+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
215215

216216
assert len(payload.traces) == 1
217217
assert len(payload.traces[0].spans) == 1
@@ -238,7 +238,7 @@ def my_function(arg1, arg2):
238238
my_function(1, 2)
239239
galileo_context.flush()
240240

241-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
241+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
242242

243243
assert len(payload.traces) == 1
244244
assert len(payload.traces[0].spans) == 1
@@ -267,7 +267,7 @@ def my_function(system: Message, user: Message):
267267
)
268268
galileo_context.flush()
269269

270-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
270+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
271271

272272
assert len(payload.traces) == 1
273273
assert len(payload.traces[0].spans) == 1
@@ -302,7 +302,7 @@ def my_function(system: Message, user: Message):
302302
)
303303
galileo_context.flush()
304304

305-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
305+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
306306

307307
assert len(payload.traces) == 1
308308
assert len(payload.traces[0].spans) == 1
@@ -335,7 +335,7 @@ def my_function(arg1: str, arg2: str):
335335
my_function("arg1", "arg2")
336336
galileo_context.flush()
337337

338-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
338+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
339339

340340
assert len(payload.traces) == 1
341341
assert len(payload.traces[0].spans) == 1
@@ -362,7 +362,7 @@ def my_function(arg1: str, arg2: str):
362362
my_function("arg1", "arg2")
363363
galileo_context.flush()
364364

365-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
365+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
366366

367367
assert len(payload.traces) == 1
368368
assert len(payload.traces[0].spans) == 1
@@ -394,7 +394,7 @@ def my_function(arg1: str, arg2: str):
394394
my_function("arg1", "arg2")
395395
galileo_context.flush()
396396

397-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
397+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
398398

399399
assert len(payload.traces) == 1
400400
assert len(payload.traces[0].spans) == 1
@@ -431,7 +431,7 @@ def nested_call(nested_query: str):
431431
output = nested_call(nested_query="input")
432432
galileo_context.flush()
433433

434-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
434+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
435435

436436
assert len(payload.traces) == 1
437437
assert len(payload.traces[0].spans) == 1
@@ -468,7 +468,7 @@ def nested_call(nested_query: str):
468468
output = nested_call(nested_query="input")
469469
galileo_context.flush()
470470

471-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
471+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
472472

473473
assert len(payload.traces) == 1
474474
assert len(payload.traces[0].spans) == 1
@@ -500,7 +500,7 @@ def retriever_call(query: str):
500500
retriever_call(query="input")
501501
galileo_context.flush()
502502

503-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
503+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
504504

505505
assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
506506
assert payload.traces[0].spans[0].input == '{"query": "input"}'
@@ -524,7 +524,7 @@ def retriever_call(query: str):
524524
retriever_call(query="input")
525525
galileo_context.flush()
526526

527-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
527+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
528528

529529
assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
530530
assert payload.traces[0].spans[0].input == '{"query": "input"}'
@@ -551,7 +551,7 @@ def retriever_call(query: str):
551551
retriever_call(query="input")
552552
galileo_context.flush()
553553

554-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
554+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
555555

556556
assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
557557
assert payload.traces[0].spans[0].input == '{"query": "input"}'
@@ -578,7 +578,7 @@ def retriever_call(query: str):
578578
retriever_call(query="input")
579579
galileo_context.flush()
580580

581-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
581+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
582582

583583
assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
584584
assert payload.traces[0].spans[0].input == '{"query": "input"}'
@@ -607,7 +607,7 @@ def foo():
607607

608608
galileo_context.flush()
609609

610-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
610+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
611611

612612
assert len(payload.traces) == 1
613613
assert len(payload.traces[0].spans) == 1
@@ -637,7 +637,7 @@ def foo():
637637

638638
galileo_context.flush()
639639

640-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
640+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
641641

642642
assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")
643643

@@ -664,7 +664,7 @@ def foo():
664664

665665
galileo_context.flush()
666666

667-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
667+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
668668

669669
assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")
670670

@@ -695,7 +695,7 @@ def foo():
695695

696696
galileo_context.flush()
697697

698-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
698+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
699699

700700
assert payload.session_id is None
701701

@@ -722,7 +722,7 @@ def foo():
722722

723723
galileo_context.flush()
724724

725-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
725+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
726726

727727
assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")
728728

@@ -751,7 +751,7 @@ def foo(input: str):
751751

752752
logger.flush()
753753

754-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
754+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
755755

756756
assert payload.traces[0].input == "test input"
757757
assert payload.traces[0].output == "test output"

tests/test_experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def test_run_experiment_with_func(
412412
mock_get_dataset_instance.get_content.assert_called()
413413

414414
# check galileo_logger
415-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
415+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
416416

417417
assert len(payload.traces) == 1
418418
trace = payload.traces[0]
@@ -594,7 +594,7 @@ def runner(input):
594594
mock_get_dataset_instance.get_content.assert_called()
595595

596596
# check galileo_logger
597-
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
597+
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
598598
assert len(payload.traces) == 1
599599
assert (
600600
payload.traces[0].input == '{"input": {"question": "Which continent is Spain in?", "expected": "Europe"}}'

0 commit comments

Comments
 (0)