1- from typing import List , Optional , Union
1+ from typing import Iterable , List , Optional , Type , Union
22
33import yaml
44from pydantic import BaseModel , Field , PrivateAttr
77from seqspec .Region import Region , RegionInput
88
99from . import __version__
10+ from ._core import Assay as _RustAssay
11+ from ._core import Read as _RustRead
12+ from ._core import Region as _RustRegion
1013
1114
1215class SeqProtocol (BaseModel ):
@@ -159,6 +162,60 @@ def to_libkit(self) -> LibKit:
159162 )
160163
161164
165+ def coerce_protocol_kit_list (value , cls : Type [BaseModel ], modalities : Iterable [str ]):
166+ """
167+ Coerce a string or list of strings/objects/dicts into a list of protocol/kit objects (or None).
168+
169+ Supports:
170+ - "NovaSeq" -> [cls(protocol_id|kit_id="NovaSeq", name="NovaSeq", modality=m) for m in modalities]
171+ - ["A","B"] -> expanded per modality
172+ - [{"protocol_id": "...", ...}] -> cls(**dict)
173+ - [cls(...), "X", {...}] -> mixed inputs
174+ """
175+ if value is None :
176+ return None
177+
178+ # identify target family (protocol vs kit) by class
179+ is_protocol = cls .__name__ in {"SeqProtocol" , "LibProtocol" }
180+ is_kit = cls .__name__ in {"SeqKit" , "LibKit" }
181+
182+ if not (is_protocol or is_kit ):
183+ raise ValueError ("cls must be one of: SeqProtocol, LibProtocol, SeqKit, LibKit" )
184+
185+ def make_obj (val , modality : str ):
186+ if isinstance (val , cls ):
187+ return val
188+ if isinstance (val , dict ):
189+ return cls (** val )
190+ if isinstance (val , str ):
191+ if is_protocol :
192+ return cls (protocol_id = val , name = val , modality = modality )
193+ else :
194+ return cls (kit_id = val , name = val , modality = modality )
195+ raise TypeError (f"Unsupported item type for { cls .__name__ } : { type (val )!r} " )
196+
197+ if isinstance (value , str ):
198+ return [make_obj (value , m ) for m in modalities ]
199+
200+ if isinstance (value , list ):
201+ out = []
202+ for item in value :
203+ if isinstance (item , str ):
204+ out .extend (make_obj (item , m ) for m in modalities )
205+ else :
206+ # dict or already-typed object: keep as a single item
207+ # if it lacks modality, caller's responsibility (your spec usually includes it)
208+ out .append (make_obj (item , next (iter (modalities ), "" )))
209+ return out
210+
211+ # already a typed object (rare), wrap into list
212+ if isinstance (value , cls ):
213+ return [value ]
214+
215+ # last resort: pass through
216+ return value
217+
218+
162219class Assay (BaseModel ):
163220 seqspec_version : Optional [str ] = __version__
164221 assay_id : str
@@ -180,6 +237,9 @@ class Assay(BaseModel):
180237 # Not part of the public schema; populated when loading from disk.
181238 _spec_path : Optional [str ] = PrivateAttr (default = None )
182239
240+ def model_post_init (self , __context ) -> None :
241+ self .normalize_protocols_kits ()
242+
183243 def __repr__ (self ) -> str :
184244 rds = []
185245 rgns = []
@@ -214,6 +274,7 @@ def print_sequence(self):
214274 print ("\n " , end = "" )
215275
216276 def update_spec (self ):
277+ self .normalize_protocols_kits ()
217278 for r in self .library_spec :
218279 r .update_attr ()
219280
@@ -294,6 +355,82 @@ def insert_reads(
294355 self .sequence_spec .insert (insert_idx , read )
295356 insert_idx += 1
296357
358+ def normalize_protocols_kits (self ) -> None :
359+ """Normalize str-valued protocol/kit fields into lists of objects."""
360+ self .sequence_protocol = coerce_protocol_kit_list (
361+ self .sequence_protocol , SeqProtocol , self .modalities
362+ )
363+ self .sequence_kit = coerce_protocol_kit_list (
364+ self .sequence_kit , SeqKit , self .modalities
365+ )
366+ self .library_protocol = coerce_protocol_kit_list (
367+ self .library_protocol , LibProtocol , self .modalities
368+ )
369+ self .library_kit = coerce_protocol_kit_list (
370+ self .library_kit , LibKit , self .modalities
371+ )
372+
373+
374+ class RustAssay :
375+ __slots__ = ("_inner" ,)
376+
377+ def __init__ (self , inner : _RustAssay ) -> None :
378+ object .__setattr__ (self , "_inner" , inner )
379+
380+ # generic forwarding
381+ def __getattr__ (self , name ):
382+ return getattr (self ._inner , name )
383+
384+ def __setattr__ (self , name , value ):
385+ if name == "_inner" :
386+ return object .__setattr__ (self , name , value )
387+ return setattr (self ._inner , name , value )
388+
389+ # constructors
390+ @classmethod
391+ def from_model (cls , m : "Assay" ) -> "RustAssay" :
392+ return cls (_RustAssay .from_json (m .model_dump_json ()))
393+
394+ def snapshot (self ) -> "Assay" :
395+ return Assay .model_validate_json (self ._inner .to_json ())
396+
397+ # helpers: DTO outputs for downstream Python code
398+ def list_modalities (self ) -> List [str ]:
399+ return list (self ._inner .list_modalities ())
400+
401+ def get_libspec (self , modality : str ) -> Region :
402+ r : _RustRegion = self ._inner .get_libspec (modality )
403+ return Region .model_validate_json (r .to_json ())
404+
405+ def get_seqspec (self , modality : str ) -> List [Read ]:
406+ rlist : List [_RustRead ] = self ._inner .get_seqspec (modality )
407+ return [Read .model_validate_json (r .to_json ()) for r in rlist ]
408+
409+ def get_read (self , read_id : str ) -> Read :
410+ r : _RustRead = self ._inner .get_read (read_id )
411+ return Read .model_validate_json (r .to_json ())
412+
413+ def update_spec (self ) -> None :
414+ self ._inner .update_spec ()
415+
416+ def insert_reads (
417+ self , reads : List [Read ], modality : str , after : Optional [str ] = None
418+ ) -> None :
419+ # Convert DTOs to Rust via JSON (serde builds Vec<Read>)
420+ raw : List [_RustRead ] = [_RustRead .from_json (r .model_dump_json ()) for r in reads ]
421+ self ._inner .insert_reads (raw , modality , after )
422+
423+ def insert_regions (
424+ self , regions : List [Region ], modality : str , after : Optional [str ] = None
425+ ) -> None :
426+ raw : List [_RustRegion ] = [
427+ _RustRegion .from_json (r .model_dump_json ()) for r in regions
428+ ]
429+ self ._inner .insert_regions (raw , modality , after )
430+
431+ def __repr__ (self ) -> str :
432+ return self ._inner .__repr__ ()
433+
297434
298435class AssayInput (BaseModel ):
299436 """
0 commit comments