Skip to content
Draft

poc #263

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/galileo/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,24 @@ def start_session(
name=name, previous_session_id=previous_session_id, external_id=external_id
)

async def async_start_session(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ignore this change for now

self, name: Optional[str] = None, previous_session_id: Optional[str] = None, external_id: Optional[str] = None
) -> str:
"""
Async start a session in the active context logger instance.

Args:
name: The name of the session. If not provided, a session name will be generated automatically.
previous_session_id: The id of the previous session. Defaults to None.
external_id: The external id of the session. Defaults to None.

Returns:
str: The id of the newly created session.
"""
return await self.get_logger_instance().async_start_session(
name=name, previous_session_id=previous_session_id, external_id=external_id
)

def clear_session(self) -> None:
"""Clear the session in the active context logger instance."""
self.get_logger_instance().clear_session()
Expand Down
87 changes: 41 additions & 46 deletions src/galileo/logger/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,22 +1016,8 @@ def conclude(

return current_parent

@nop_sync
def flush(self) -> list[Trace]:
"""
Upload all traces to Galileo.

Returns:
-------
List[Trace]: The list of uploaded traces.
"""
if self.mode == "batch":
return self._flush_batch()
else:
self._logger.warning("Flushing in streaming mode is not supported.")
return list()

def _flush_batch(self):
async def _flush_batch(self, is_async: bool = False) -> list[Trace]:
# import pdb; pdb.set_trace()
if not self.traces:
self._logger.info("No traces to flush.")
return list()
Expand All @@ -1044,7 +1030,7 @@ def _flush_batch(self):

if self.local_metrics:
self._logger.info("Computing local metrics...")
# TODO: parallelize, possibly with ThreadPoolExecutor
# TODO: parallelize, possibly with ThreadPoolExecutor/asyncio
for trace in self.traces:
populate_local_metrics(trace, self.local_metrics)

Expand All @@ -1053,7 +1039,16 @@ def _flush_batch(self):
traces_ingest_request = TracesIngestRequest(
traces=self.traces, experiment_id=self.experiment_id, session_id=self.session_id
)
self._client.ingest_traces_sync(traces_ingest_request)

if is_async:
await self._client.ingest_traces(traces_ingest_request)
else:
# Use async_run() instead of asyncio.run() to work in all environments
# (Jupyter notebooks, pytest-asyncio, FastAPI, etc.)
from galileo_core.helpers.execution import async_run

async_run(self._client.ingest_traces(traces_ingest_request))

logged_traces = self.traces

self._logger.info("Successfully flushed %d traces.", len(logged_traces))
Expand All @@ -1072,39 +1067,39 @@ async def async_flush(self) -> list[Trace]:
List[Trace]: The list of uploaded workflows.
"""
if self.mode == "batch":
return await self._async_flush_batch()
return await self._flush_batch(is_async=True)
else:
self._logger.warning("Flushing in streaming mode is not supported.")
return list()

async def _async_flush_batch(self) -> list[Trace]:
if not self.traces:
self._logger.info("No traces to flush.")
return list()

current_parent = self.current_parent()
if current_parent is not None:
self._logger.info("Concluding the active trace...")
last_output = get_last_output(current_parent)
self.conclude(output=last_output, conclude_all=True)

if self.local_metrics:
self._logger.info("Computing metrics for local scorers...")
# TODO: parallelize, possibly with asyncio to_thread/gather
for trace in self.traces:
populate_local_metrics(trace, self.local_metrics)

self._logger.info("Flushing %d traces...", len(self.traces))

traces_ingest_request = TracesIngestRequest(traces=self.traces, session_id=self.session_id)
await self._client.ingest_traces(traces_ingest_request)
logged_traces = self.traces

self._logger.info("Successfully flushed %d traces.", len(logged_traces))
@nop_sync
def flush(self) -> list[Trace]:
"""
Upload all traces to Galileo.

self.traces = list()
self._parent_stack = deque()
return logged_traces
Returns:
-------
List[Trace]: The list of uploaded traces.
"""
if self.mode == "batch":
# This is bad because asyncio.run() fails in environments with existing event loops
# (e.g. jupyter notebooks, FastAPI, etc. would fail with "cannot be called from a running event loop"
# Even though flush() is sync, it can be called from async contexts like:
# - Jupyter notebooks (which have their own event loop)
# - pytest-asyncio tests (where @mark.asyncio creates an event loop)
# - FastAPI/Django async views (where the web framework has an event loop)
# - Any async function that calls sync code
# The EventLoopThreadPool approach works in ALL environments by using dedicated threads
# return asyncio.run(self._flush_batch(is_async=False))

# This is good because async_run() uses EventLoopThreadPool which works in all environments
# by running async code in dedicated threads with their own event loops
from galileo_core.helpers.execution import async_run

return async_run(self._flush_batch(is_async=False))
else:
self._logger.warning("Flushing in streaming mode is not supported.")
return list()

@nop_sync
def terminate(self) -> None:
Expand Down
12 changes: 0 additions & 12 deletions src/galileo/utils/core_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,18 +85,6 @@ async def ingest_traces(self, traces_ingest_request: TracesIngestRequest) -> dic
RequestMethod.POST, endpoint=Routes.traces.format(project_id=self.project_id), json=json
)

def ingest_traces_sync(self, traces_ingest_request: TracesIngestRequest) -> dict[str, str]:
if self.experiment_id:
traces_ingest_request.experiment_id = UUID(self.experiment_id)
elif self.log_stream_id:
traces_ingest_request.log_stream_id = UUID(self.log_stream_id)

json = traces_ingest_request.model_dump(mode="json")

return self._make_request(
RequestMethod.POST, endpoint=Routes.traces.format(project_id=self.project_id), json=json
)

async def ingest_spans(self, spans_ingest_request: SpansIngestRequest) -> dict[str, str]:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to do what I did with `ingest_traces` to all other functions in this file. Do we really need both `_make_request` and `_make_async_request` when, under the hood, the sync request calls `async_run` in the end (via `galileo-core`)? We have a lot of logic duplication in this file.

if self.experiment_id:
spans_ingest_request.experiment_id = UUID(self.experiment_id)
Expand Down
44 changes: 22 additions & 22 deletions tests/test_decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def llm_call(query: str):

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -130,7 +130,7 @@ def llm_call(query: str):

galileo_context.flush(project="project-X", log_stream="log-stream-X")

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand All @@ -139,7 +139,7 @@ def llm_call(query: str):

galileo_context.flush(project="project-Y", log_stream="log-stream-Y")

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -211,7 +211,7 @@ def llm_call(query: str):
llm_call(query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand All @@ -238,7 +238,7 @@ def my_function(arg1, arg2):
my_function(1, 2)
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -267,7 +267,7 @@ def my_function(system: Message, user: Message):
)
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -302,7 +302,7 @@ def my_function(system: Message, user: Message):
)
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -335,7 +335,7 @@ def my_function(arg1: str, arg2: str):
my_function("arg1", "arg2")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand All @@ -362,7 +362,7 @@ def my_function(arg1: str, arg2: str):
my_function("arg1", "arg2")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -394,7 +394,7 @@ def my_function(arg1: str, arg2: str):
my_function("arg1", "arg2")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -431,7 +431,7 @@ def nested_call(nested_query: str):
output = nested_call(nested_query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -468,7 +468,7 @@ def nested_call(nested_query: str):
output = nested_call(nested_query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -500,7 +500,7 @@ def retriever_call(query: str):
retriever_call(query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
assert payload.traces[0].spans[0].input == '{"query": "input"}'
Expand All @@ -524,7 +524,7 @@ def retriever_call(query: str):
retriever_call(query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
assert payload.traces[0].spans[0].input == '{"query": "input"}'
Expand All @@ -551,7 +551,7 @@ def retriever_call(query: str):
retriever_call(query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
assert payload.traces[0].spans[0].input == '{"query": "input"}'
Expand All @@ -578,7 +578,7 @@ def retriever_call(query: str):
retriever_call(query="input")
galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert isinstance(payload.traces[0].spans[0], RetrieverSpan)
assert payload.traces[0].spans[0].input == '{"query": "input"}'
Expand Down Expand Up @@ -607,7 +607,7 @@ def foo():

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
assert len(payload.traces[0].spans) == 1
Expand Down Expand Up @@ -637,7 +637,7 @@ def foo():

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")

Expand All @@ -664,7 +664,7 @@ def foo():

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")

Expand Down Expand Up @@ -695,7 +695,7 @@ def foo():

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert payload.session_id is None

Expand All @@ -722,7 +722,7 @@ def foo():

galileo_context.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert payload.session_id == UUID("6c4e3f7e-4a9a-4e7e-8c1f-3a9a3a9a3a9c")

Expand Down Expand Up @@ -751,7 +751,7 @@ def foo(input: str):

logger.flush()

payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert payload.traces[0].input == "test input"
assert payload.traces[0].output == "test output"
Expand Down
4 changes: 2 additions & 2 deletions tests/test_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ def test_run_experiment_with_func(
mock_get_dataset_instance.get_content.assert_called()

# check galileo_logger
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]

assert len(payload.traces) == 1
trace = payload.traces[0]
Expand Down Expand Up @@ -594,7 +594,7 @@ def runner(input):
mock_get_dataset_instance.get_content.assert_called()

# check galileo_logger
payload = mock_core_api_instance.ingest_traces_sync.call_args[0][0]
payload = mock_core_api_instance.ingest_traces.call_args[0][0]
assert len(payload.traces) == 1
assert (
payload.traces[0].input == '{"input": {"question": "Which continent is Spain in?", "expected": "Europe"}}'
Expand Down
Loading
Loading