@@ -967,6 +967,93 @@ def observe(
967967 def decorator (func ):
968968 func_name = func .__name__ # Get func_name outside wrappers
969969
970+ # Async generator function
971+ if inspect .isasyncgenfunction (func ):
972+
973+ @functools .wraps (func )
974+ def asyncgen_wrapper (* args , ** func_kwargs ):
975+
976+ sig = inspect .signature (func )
977+ bound = sig .bind (* args , ** func_kwargs )
978+ bound .apply_defaults ()
979+
980+ complete_kwargs = dict (bound .arguments )
981+ if "self" in complete_kwargs :
982+ complete_kwargs ["self" ] = replace_self_with_class_name (
983+ complete_kwargs ["self" ]
984+ )
985+ observer_kwargs = {
986+ "observe_kwargs" : observe_kwargs ,
987+ "function_kwargs" : complete_kwargs ,
988+ }
989+
990+ observer = Observer (
991+ type ,
992+ metrics = metrics ,
993+ metric_collection = metric_collection ,
994+ func_name = func_name ,
995+ ** observer_kwargs ,
996+ )
997+ observer .__enter__ ()
998+ agen = func (* args , ** func_kwargs )
999+
1000+ async def gen ():
1001+ try :
1002+ async for chunk in agen :
1003+ yield chunk
1004+ observer .__exit__ (None , None , None )
1005+ except Exception as e :
1006+ observer .__exit__ (type (e ), e , e .__traceback__ )
1007+ raise
1008+
1009+ return gen ()
1010+
1011+ setattr (asyncgen_wrapper , "_is_deepeval_observed" , True )
1012+ return asyncgen_wrapper
1013+
1014+ # Sync generator function
1015+ if inspect .isgeneratorfunction (func ):
1016+
1017+ @functools .wraps (func )
1018+ def gen_wrapper (* args , ** func_kwargs ):
1019+
1020+ sig = inspect .signature (func )
1021+ bound = sig .bind (* args , ** func_kwargs )
1022+ bound .apply_defaults ()
1023+ complete_kwargs = dict (bound .arguments )
1024+
1025+ if "self" in complete_kwargs :
1026+ complete_kwargs ["self" ] = replace_self_with_class_name (
1027+ complete_kwargs ["self" ]
1028+ )
1029+ observer_kwargs = {
1030+ "observe_kwargs" : observe_kwargs ,
1031+ "function_kwargs" : make_json_serializable (complete_kwargs ),
1032+ }
1033+
1034+ observer = Observer (
1035+ type ,
1036+ metrics = metrics ,
1037+ metric_collection = metric_collection ,
1038+ func_name = func_name ,
1039+ ** observer_kwargs ,
1040+ )
1041+ observer .__enter__ ()
1042+ original_gen = func (* args , ** func_kwargs )
1043+
1044+ def gen ():
1045+ try :
1046+ yield from original_gen
1047+ observer .__exit__ (None , None , None )
1048+ except Exception as e :
1049+ observer .__exit__ (type (e ), e , e .__traceback__ )
1050+ raise
1051+
1052+ return gen ()
1053+
1054+ setattr (gen_wrapper , "_is_deepeval_observed" , True )
1055+ return gen_wrapper
1056+
9701057 if asyncio .iscoroutinefunction (func ):
9711058
9721059 @functools .wraps (func )
0 commit comments