diff --git a/deepeval/tracing/tracing.py b/deepeval/tracing/tracing.py index e88a0a10a..9456e0f5c 100644 --- a/deepeval/tracing/tracing.py +++ b/deepeval/tracing/tracing.py @@ -156,6 +156,7 @@ def __del__(self): class TraceManager: def __init__(self): + self._state_lock = threading.RLock() self.traces: List[Trace] = [] self.active_traces: Dict[str, Trace] = {} # Map of trace_uuid to Trace self.active_spans: Dict[str, BaseSpan] = ( @@ -270,10 +271,11 @@ def start_new_trace( metric_collection=metric_collection, confident_api_key=self.confident_api_key, ) - self.active_traces[trace_uuid] = new_trace - self.traces.append(new_trace) - if self.evaluation_loop: - self.traces_to_evaluate_order.append(trace_uuid) + with self._state_lock: + self.active_traces[trace_uuid] = new_trace + self.traces.append(new_trace) + if self.evaluation_loop: + self.traces_to_evaluate_order.append(trace_uuid) # Associate the current Golden with this trace so we can # later evaluate traces against the correct golden, even if more traces # are created than goldens or the order interleaves. @@ -291,8 +293,10 @@ def start_new_trace( def end_trace(self, trace_uuid: str): """End a specific trace by its UUID.""" + with self._state_lock: + if trace_uuid not in self.active_traces: + return - if trace_uuid in self.active_traces: trace = self.active_traces[trace_uuid] trace.end_time = ( perf_counter() if trace.end_time is None else trace.end_time @@ -354,66 +358,86 @@ def end_trace(self, trace_uuid: str): def set_trace_status(self, trace_uuid: str, status: TraceSpanStatus): """Manually set the status of a trace.""" - if trace_uuid in self.active_traces: - trace = self.active_traces[trace_uuid] - trace.status = status + with self._state_lock: + if trace_uuid in self.active_traces: + trace = self.active_traces[trace_uuid] + trace.status = status def add_span(self, span: BaseSpan): """Add a span to the active spans dictionary.""" - self.active_spans[span.uuid] = span + with self._state_lock: + self.active_spans[span.uuid] = span def remove_span(self, span_uuid: str): """Remove a span from the active spans dictionary.""" - if span_uuid in self.active_spans: - del self.active_spans[span_uuid] + with self._state_lock: + if span_uuid in self.active_spans: + del self.active_spans[span_uuid] def add_span_to_trace(self, span: BaseSpan): """Add a span to its trace.""" - trace_uuid = span.trace_uuid - if trace_uuid not in self.active_traces: - raise ValueError( - f"Trace with UUID {trace_uuid} does not exist. A span must have a valid trace." - ) + with self._state_lock: + trace_uuid = span.trace_uuid + if trace_uuid not in self.active_traces: + raise ValueError( + f"Trace with UUID {trace_uuid} does not exist. A span must have a valid trace." + ) - trace = self.active_traces[trace_uuid] + trace = self.active_traces[trace_uuid] - # If this is a root span (no parent), add it to the trace's root_spans - if not span.parent_uuid: - trace.root_spans.append(span) - else: - # This is a child span, find its parent and add it to the parent's children - parent_span = self.get_span_by_uuid(span.parent_uuid) - if parent_span: + # If this is a root span (no parent), add it to the trace's root_spans + if not span.parent_uuid: + trace.root_spans.append(span) + else: + # This is a child span, find its parent and add it to the parent's children + parent_span = self.get_span_by_uuid(span.parent_uuid) + if parent_span: - if ( - parent_span.name == EVAL_DUMMY_SPAN_NAME - ): # ignored span for evaluation - span.parent_uuid = None - trace.root_spans.remove(parent_span) + if ( + parent_span.name == EVAL_DUMMY_SPAN_NAME + ): # ignored span for evaluation + span.parent_uuid = None + trace.root_spans.remove(parent_span) + trace.root_spans.append(span) + return + + parent_span.children.append(span) + else: trace.root_spans.append(span) - return - - parent_span.children.append(span) - else: - trace.root_spans.append(span) def get_trace_by_uuid(self, trace_uuid: str) -> Optional[Trace]: """Get a trace by its UUID.""" - return self.active_traces.get(trace_uuid) + with self._state_lock: + return self.active_traces.get(trace_uuid) def get_span_by_uuid(self, span_uuid: str) -> Optional[BaseSpan]: """Get a span by its UUID.""" - return self.active_spans.get(span_uuid) + with self._state_lock: + return self.active_spans.get(span_uuid) + + def is_trace_active(self, trace_uuid: str) -> bool: + with self._state_lock: + return trace_uuid in self.active_traces + + def get_active_spans_for_trace(self, trace_uuid: str) -> List[BaseSpan]: + with self._state_lock: + return [ + span + for span in self.active_spans.values() + if span.trace_uuid == trace_uuid + ] def get_all_traces(self) -> List[Trace]: """Get all traces.""" - return self.traces + with self._state_lock: + return list(self.traces) def clear_traces(self): """Clear all traces.""" - self.traces = [] - self.active_traces = {} - self.active_spans = {} + with self._state_lock: + self.traces = [] + self.active_traces = {} + self.active_spans = {} def get_trace_dict(self, trace: Trace) -> Dict: """Convert a trace to a dictionary.""" @@ -421,7 +445,9 @@ def get_trace_dict(self, trace: Trace) -> Dict: def get_all_traces_dict(self) -> List[Dict]: """Get all traces as dictionaries.""" - return [self.get_trace_dict(trace) for trace in self.traces] + with self._state_lock: + traces_snapshot = list(self.traces) + return [self.get_trace_dict(trace) for trace in traces_snapshot] def _print_trace_status( self, @@ -959,7 +985,7 @@ def __enter__(self): # (a previous failed async operation might leave a dead trace in context) if ( current_trace - and current_trace.uuid in trace_manager.active_traces + and trace_manager.is_trace_active(current_trace.uuid) ): self.trace_uuid = current_trace.uuid else: @@ -1095,11 +1121,11 @@ def __exit__(self, exc_type, exc_val, exc_tb): if current_span.status == TraceSpanStatus.ERRORED: current_trace.status = TraceSpanStatus.ERRORED if current_trace.uuid == current_span.trace_uuid: - other_active_spans = [ - span - for span in trace_manager.active_spans.values() - if span.trace_uuid == current_span.trace_uuid - ] + other_active_spans = ( + trace_manager.get_active_spans_for_trace( + current_span.trace_uuid + ) + ) if not other_active_spans: trace_manager.end_trace(current_span.trace_uuid)