@@ -99,6 +99,10 @@ def to_json(self) -> Mapping[str, Any]:
9999 "controls" : [c .to_json () for c in self .controls ],
100100 }
101101
102+ @staticmethod
103+ def from_json (json : Mapping [str , Any ]) -> "Tokens" :
104+ return Tokens (tokens = json ["data" ], controls = [])
105+
102106 @staticmethod
103107 def from_token_ids (token_ids : Sequence [int ]) -> "Tokens" :
104108 return Tokens (token_ids , [])
@@ -173,20 +177,23 @@ def to_json(self) -> Mapping[str, Any]:
173177 "controls" : [control .to_json () for control in self .controls ],
174178 }
175179
180+ @staticmethod
181+ def from_json (json : Mapping [str , Any ]) -> "Text" :
182+ return Text .from_text (json ["data" ])
183+
176184 @staticmethod
177185 def from_text (text : str ) -> "Text" :
178186 return Text (text , [])
179187
180188
181- class Cropping :
189+ class Cropping ( NamedTuple ) :
182190 """
183191 Describes a quadratic crop of the file.
184192 """
185193
186- def __init__ (self , upper_left_x : int , upper_left_y : int , size : int ):
187- self .upper_left_x = upper_left_x
188- self .upper_left_y = upper_left_y
189- self .size = size
194+ upper_left_x : int
195+ upper_left_y : int
196+ size : int
190197
191198
192199class ImageControl (NamedTuple ):
@@ -254,7 +261,7 @@ def to_json(self) -> Mapping[str, Any]:
254261 return payload
255262
256263
257- class Image :
264+ class Image ( NamedTuple ) :
258265 """
259266 An image send as part of a prompt to a model. The image is represented as
260267 base64.
@@ -272,17 +279,11 @@ class Image:
272279 >>> image = Image.from_url(url)
273280 """
274281
275- def __init__ (
276- self ,
277- base_64 : str ,
278- cropping : Optional [Cropping ],
279- controls : Sequence [ImageControl ],
280- ):
281- # We use a base_64 reperesentation, because we want to embed the image
282- # into a prompt send in JSON.
283- self .base_64 = base_64
284- self .cropping = cropping
285- self .controls : Sequence [ImageControl ] = controls
282+ # We use a base_64 reperesentation, because we want to embed the image
283+ # into a prompt send in JSON.
284+ base_64 : str
285+ cropping : Optional [Cropping ]
286+ controls : Sequence [ImageControl ]
286287
287288 @classmethod
288289 def from_image_source (
@@ -357,7 +358,9 @@ def from_url_with_cropping(
357358 return cls .from_bytes (bytes , cropping = cropping , controls = controls or [])
358359
359360 @classmethod
360- def from_file (cls , path : Union [str , Path ], controls : Optional [Sequence [ImageControl ]] = None ):
361+ def from_file (
362+ cls , path : Union [str , Path ], controls : Optional [Sequence [ImageControl ]] = None
363+ ):
361364 """
362365 Load an image from disk and prepare it to be used in a prompt
363366 If they are not provided then the image will be [center cropped](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.CenterCrop)
@@ -412,6 +415,10 @@ def to_json(self) -> Dict[str, Any]:
412415 "controls" : [control .to_json () for control in self .controls ],
413416 }
414417
418+ @staticmethod
419+ def from_json (json : Mapping [str , Any ]) -> "Image" :
420+ return Image (base_64 = json ["data" ], cropping = None , controls = [])
421+
415422 def to_image (self ) -> PILImage :
416423 return PIL .Image .open (io .BytesIO (base64 .b64decode (self .base_64 )))
417424
@@ -464,6 +471,29 @@ def from_tokens(
464471 def to_json (self ) -> Sequence [Mapping [str , Any ]]:
465472 return [_to_json (item ) for item in self .items ]
466473
474+ @staticmethod
475+ def from_json (items_json : Sequence [Mapping [str , Any ]]) -> "Prompt" :
476+ return Prompt (
477+ [
478+ item
479+ for item in (_prompt_item_from_json (item ) for item in items_json )
480+ if item
481+ ]
482+ )
483+
484+
485+ def _prompt_item_from_json (item : Mapping [str , Any ]) -> Optional [PromptItem ]:
486+ item_type = item .get ("type" )
487+ if item_type == "text" :
488+ return Text .from_json (item )
489+ if item_type == "image" :
490+ return Image .from_json (item )
491+ if item_type == "token_ids" :
492+ return Tokens .from_json (item )
493+ # Skip item instead of raising an error to prevent failures of old clients
494+ # when item types are extended
495+ return None
496+
467497
468498def _to_json (item : PromptItem ) -> Mapping [str , Any ]:
469499 if hasattr (item , "to_json" ):
0 commit comments