1212# Import Literal with Python 3.7 fallback
1313from typing_extensions import Literal
1414
15+ from aleph_alpha_client import Text
16+
1517from aleph_alpha_client .prompt import ControlTokenOverlap , Image , Prompt , PromptItem
1618
1719
@@ -204,6 +206,20 @@ def from_json(score: Any) -> "TextScore":
204206 score = score ["score" ],
205207 )
206208
209+ class TextScoreWithRaw (NamedTuple ):
210+ start : int
211+ length : int
212+ score : float
213+ text : str
214+
215+ @staticmethod
216+ def from_text_score (score : TextScore , prompt : Text ) -> "TextScoreWithRaw" :
217+ return TextScoreWithRaw (
218+ start = score .start ,
219+ length = score .length ,
220+ score = score .score ,
221+ text = prompt .text [score .start :score .start + score .length ],
222+ )
207223
208224class ImageScore (NamedTuple ):
209225 left : float
@@ -236,6 +252,20 @@ def from_json(score: Any) -> "TargetScore":
236252 score = score ["score" ],
237253 )
238254
255+ class TargetScoreWithRaw (NamedTuple ):
256+ start : int
257+ length : int
258+ score : float
259+ text : str
260+
261+ @staticmethod
262+ def from_target_score (score : TargetScore , target : str ) -> "TargetScoreWithRaw" :
263+ return TargetScoreWithRaw (
264+ start = score .start ,
265+ length = score .length ,
266+ score = score .score ,
267+ text = target [score .start :score .start + score .length ],
268+ )
239269
240270class TokenScore (NamedTuple ):
241271 score : float
@@ -275,23 +305,37 @@ def in_pixels(self, prompt_item: PromptItem) -> "ImagePromptItemExplanation":
275305
276306
277307class TextPromptItemExplanation (NamedTuple ):
278- scores : List [TextScore ]
308+ scores : List [Union [ TextScore , TextScoreWithRaw ] ]
279309
280310 @staticmethod
281311 def from_json (item : Dict [str , Any ]) -> "TextPromptItemExplanation" :
282312 return TextPromptItemExplanation (
283313 scores = [TextScore .from_json (score ) for score in item ["scores" ]]
284314 )
315+
316+ def with_text (self , prompt : Text ) -> "TextPromptItemExplanation" :
317+ return TextPromptItemExplanation (
318+ scores = [TextScoreWithRaw .from_text_score (score , prompt ) if isinstance (score , TextScore ) else score for score in self .scores ]
319+ )
320+
285321
286322
287323class TargetPromptItemExplanation (NamedTuple ):
288- scores : List [TargetScore ]
324+ scores : List [Union [ TargetScore , TargetScoreWithRaw ] ]
289325
290326 @staticmethod
291327 def from_json (item : Dict [str , Any ]) -> "TargetPromptItemExplanation" :
292328 return TargetPromptItemExplanation (
293329 scores = [TargetScore .from_json (score ) for score in item ["scores" ]]
294330 )
331+
332+ def with_text (self , prompt : str ) -> "TargetPromptItemExplanation" :
333+ return TargetPromptItemExplanation (
334+ scores = [TargetScoreWithRaw .from_target_score (score , prompt ) if isinstance (score , TargetScore ) else score for score in self .scores ]
335+ )
336+
337+
338+
295339
296340
297341class TokenPromptItemExplanation (NamedTuple ):
@@ -352,6 +396,31 @@ def with_image_prompt_items_in_pixels(self, prompt: Prompt) -> "Explanation":
352396 ],
353397 )
354398
399+ def with_text_from_prompt (self , prompt : Prompt , target : str ) -> "Explanation" :
400+ items : List [Union [
401+ TextPromptItemExplanation ,
402+ ImagePromptItemExplanation ,
403+ TargetPromptItemExplanation ,
404+ TokenPromptItemExplanation ,
405+ ]] = []
406+ for item_index , item in enumerate (self .items ):
407+ if isinstance (item , TextPromptItemExplanation ):
408+ # separate variable to fix linting error
409+ prompt_item = prompt .items [item_index ]
410+ if isinstance (prompt_item , Text ):
411+ items .append (item .with_text (prompt_item ))
412+ else :
413+ items .append (item )
414+ elif isinstance (item , TargetPromptItemExplanation ):
415+ items .append (item .with_text (target ))
416+ else :
417+ items .append (item )
418+ return Explanation (
419+ target = self .target ,
420+ items = items ,
421+ )
422+
423+
355424
356425class ExplanationResponse (NamedTuple ):
357426 model_version : str
@@ -375,3 +444,12 @@ def with_image_prompt_items_in_pixels(
375444 for explanation in self .explanations
376445 ]
377446 return ExplanationResponse (self .model_version , mapped_explanations )
447+
448+ def with_text_from_prompt (
449+ self , request : ExplanationRequest
450+ ) -> "ExplanationResponse" :
451+ mapped_explanations = [
452+ explanation .with_text_from_prompt (request .prompt , request .target )
453+ for explanation in self .explanations
454+ ]
455+ return ExplanationResponse (self .model_version , mapped_explanations )
0 commit comments