2121import boto3
2222import botocore
2323import httpx
24+ import pydantic
2425from asgiref .sync import async_to_sync
2526from botocore .auth import SigV4Auth
2627from botocore .awsrequest import AWSRequest
@@ -288,6 +289,17 @@ class InferenceTask(BaseModel):
288289 timeout : timedelta
289290
290291
class InferenceResult(BaseModel):
    """Immutable, validated form of the ``inference_result.json`` document.

    The executor builds an instance by passing the raw object body to
    ``InferenceResult.model_validate_json``; a body that does not match
    this schema raises ``pydantic.ValidationError`` there.
    """

    # Frozen: fields cannot be reassigned once the model is validated.
    model_config = ConfigDict(frozen=True)

    # Identifier reported by the result; the executor compares it with its
    # own job id and rejects the result on mismatch.
    pk: str
    # Exit code of the user's process; the executor treats 0 as success.
    return_code: int
    # Durations reported by the result. Both keys are required but may be
    # null; the executor copies them onto itself after a completed job.
    exec_duration: timedelta | None
    invoke_duration: timedelta | None
    # Output descriptors, each validated as an ``InferenceIO``.
    outputs: list[InferenceIO]
    # NOTE(review): presumably the version of the shim that wrote the
    # document - confirm against the producer.
    sagemaker_shim_version: str
301+
302+
291303class Executor (ABC ):
292304 def __init__ (
293305 self ,
@@ -313,11 +325,15 @@ def __init__(
313325 use_warm_pool and settings .COMPONENTS_USE_WARM_POOL
314326 )
315327 self ._signing_key = signing_key
328+ self ._algorithm_model = algorithm_model
329+ self ._ground_truth = ground_truth
330+
331+ self ._exec_duration = None
332+ self ._invoke_duration = None
316333 self ._stdout = []
317334 self ._stderr = []
335+
318336 self .__s3_client = None
319- self ._algorithm_model = algorithm_model
320- self ._ground_truth = ground_truth
321337
322338 def provision (self , * , input_civs , input_prefixes ):
323339 # We cannot run everything async as it requires database access.
@@ -387,7 +403,15 @@ def stderr(self):
387403
@property
@abstractmethod
def utilization_duration(self):
    """Return the duration the compute cost is charged for, or ``None``.

    Must be implemented by subclasses. ``compute_cost_euro_millicents``
    reads this value and reports no cost when it is ``None``.
    """
407+
@property
def exec_duration(self):
    """Return the ``exec_duration`` taken from the inference result.

    The value is ``None`` until a completed job has been handled, and stays
    ``None`` when the result itself reported no value.
    """
    return self._exec_duration
411+
@property
def invoke_duration(self):
    """Return the ``invoke_duration`` taken from the inference result.

    The value is ``None`` until a completed job has been handled, and stays
    ``None`` when the result itself reported no value.
    """
    return self._invoke_duration
391415
392416 @property
393417 @abstractmethod
@@ -437,12 +461,13 @@ def _max_memory_mb(self):
437461
@property
def compute_cost_euro_millicents(self):
    """Return the compute cost for this executor, or ``None`` if unknown.

    The cost is derived by ``duration_to_millicents`` from
    ``utilization_duration`` and ``usd_cents_per_hour``. When no
    utilization duration is available there is nothing to charge for, so
    ``None`` is returned instead.
    """
    utilization_duration = self.utilization_duration

    # Guard clause: without a duration no cost can be computed.
    if utilization_duration is None:
        return None

    return duration_to_millicents(
        duration=utilization_duration,
        usd_cents_per_hour=self.usd_cents_per_hour,
    )
447472
448473 @property
@@ -467,7 +492,7 @@ def _invocation_key(self):
467492 return safe_join (self ._invocation_prefix , "invocation.json" )
468493
@property
def _inference_result_key(self):
    """Return the object key of the inference result document.

    The key is ``.sagemaker_shim/inference_result.json`` joined onto this
    executor's IO prefix with ``safe_join``.
    """
    return safe_join(
        self._io_prefix, ".sagemaker_shim", "inference_result.json"
    )
@@ -802,11 +827,11 @@ def _get_upload_input_content_task(*, content, key):
802827 key = key ,
803828 )
804829
805- def _get_task_return_code (self ):
830+ def _get_inference_result (self ):
806831 try :
807832 response = self ._s3_client .get_object (
808833 Bucket = settings .COMPONENTS_OUTPUT_BUCKET_NAME ,
809- Key = self ._result_key ,
834+ Key = self ._inference_result_key ,
810835 )
811836 except botocore .exceptions .ClientError as error :
812837 if error .response ["Error" ]["Code" ] == "404" :
@@ -834,26 +859,29 @@ def _get_task_return_code(self):
834859 )
835860
836861 try :
837- result = json .loads (body .decode ("utf-8" ))
838- except JSONDecodeError :
862+ inference_result = InferenceResult .model_validate_json (
863+ json_data = body
864+ )
865+ except pydantic .ValidationError as error :
866+ logger .error (error , exc_info = True )
839867 raise ComponentException (
840868 "The invocation request did not return valid json"
841869 )
842870
843- logger .info (f"{ result = } " )
871+ logger .info (f"{ inference_result = } " )
844872
845- if result [ "pk" ] != self ._job_id :
873+ if inference_result . pk != self ._job_id :
846874 raise RuntimeError ("Wrong result key for this job" )
847875
848- try :
849- return int (result ["return_code" ])
850- except (KeyError , ValueError ):
851- raise ComponentException (
852- "The invocation response object is not valid"
853- )
876+ return inference_result
854877
855878 def _handle_completed_job (self ):
856- users_process_exit_code = self ._get_task_return_code ()
879+ inference_result = self ._get_inference_result ()
880+
881+ self ._exec_duration = inference_result .exec_duration
882+ self ._invoke_duration = inference_result .invoke_duration
883+
884+ users_process_exit_code = inference_result .return_code
857885
858886 if users_process_exit_code == 0 :
859887 # Job's a good un
0 commit comments