99import os
1010import secrets
1111from abc import ABC , abstractmethod
12+ from datetime import timedelta
1213from json import JSONDecodeError
1314from math import ceil
1415from pathlib import Path
3031from django .utils ._os import safe_join
3132from django .utils .functional import cached_property
3233from panimg .image_builders import image_builder_mhd , image_builder_tiff
34+ from pydantic import BaseModel , ConfigDict
35+ from pydantic_core import to_jsonable_python
3336
3437from grandchallenge .cases .tasks import import_images
3538from grandchallenge .components .backends .exceptions import (
@@ -266,6 +269,25 @@ async def s3_stream_response(
266269 raise
267270
268271
272+ class InferenceIO (BaseModel ):
273+ model_config = ConfigDict (frozen = True )
274+
275+ relative_path : str
276+ bucket_name : str
277+ bucket_key : str
278+ decompress : bool
279+
280+
281+ class InferenceTask (BaseModel ):
282+ model_config = ConfigDict (frozen = True )
283+
284+ pk : str
285+ inputs : list [InferenceIO ]
286+ output_bucket_name : str
287+ output_prefix : str
288+ timeout : timedelta
289+
290+
269291class Executor (ABC ):
270292 def __init__ (
271293 self ,
@@ -285,7 +307,7 @@ def __init__(
285307 self ._job_id = job_id
286308 self ._exec_image_repo_tag = exec_image_repo_tag
287309 self ._memory_limit = memory_limit
288- self ._time_limit = time_limit
310+ self ._time_limit = timedelta ( seconds = time_limit )
289311 self ._requires_gpu_type = requires_gpu_type
290312 self ._use_warm_pool = (
291313 use_warm_pool and settings .COMPONENTS_USE_WARM_POOL
@@ -540,16 +562,16 @@ def _get_provisioning_tasks(self, *, input_civs, input_prefixes):
540562 ):
541563 provisioning_tasks .append (civ_provisioning_task .task )
542564 invocation_inputs .append (
543- {
544- " relative_path" : str (
565+ InferenceIO (
566+ relative_path = str (
545567 os .path .relpath (
546568 civ_provisioning_task .key , self ._io_prefix
547569 )
548570 ),
549- " bucket_name" : settings .COMPONENTS_INPUT_BUCKET_NAME ,
550- " bucket_key" : civ_provisioning_task .key ,
551- " decompress" : civ .decompress ,
552- }
571+ bucket_name = settings .COMPONENTS_INPUT_BUCKET_NAME ,
572+ bucket_key = civ_provisioning_task .key ,
573+ decompress = civ .decompress ,
574+ )
553575 )
554576
555577 provisioning_tasks .append (
@@ -667,14 +689,17 @@ def _get_civ_provisioning_tasks(self, *, civ, input_prefixes):
667689 def _get_create_invocation_json_task (self , * , invocation_inputs ):
668690 return self ._get_upload_input_content_task (
669691 content = json .dumps (
670- [
671- {
672- "pk" : self ._job_id ,
673- "inputs" : invocation_inputs ,
674- "output_bucket_name" : settings .COMPONENTS_OUTPUT_BUCKET_NAME ,
675- "output_prefix" : self ._io_prefix ,
676- }
677- ]
692+ to_jsonable_python (
693+ [
694+ InferenceTask (
695+ pk = self ._job_id ,
696+ inputs = invocation_inputs ,
697+ output_bucket_name = settings .COMPONENTS_OUTPUT_BUCKET_NAME ,
698+ output_prefix = self ._io_prefix ,
699+ timeout = self ._time_limit ,
700+ )
701+ ]
702+ )
678703 ).encode ("utf-8" ),
679704 key = self ._invocation_key ,
680705 )
0 commit comments