@@ -63,9 +63,32 @@ class TokenizedChunk:
     byte_ends: np.ndarray
 
 
+@dataclass(frozen=True)
+class _TokenBoundarySpans:
+    byte_starts: tuple[int, ...]
+    byte_ends: tuple[int, ...]
+
+
 @dataclass(frozen=True)
 class LiteralExample:
     abs_delta_bits: float
+    dataset_name: str
+    doc_preview: str
+    segment_text: str
+    segment_byte_start: int
+    segment_byte_end: int
+    model_a_token_spans: _TokenBoundarySpans
+    model_b_token_spans: _TokenBoundarySpans
+
+
+@dataclass(frozen=True)
+class _TokenBoundaryIndex:
+    byte_starts: np.ndarray
+    byte_ends: np.ndarray
+
+
+@dataclass(frozen=True)
+class _RenderedLiteralExample:
     dataset_name: str
     doc_preview: str
     model_a_token_boundaries: str
@@ -175,6 +198,8 @@ def add_document(
         byte_offsets = char_to_byte_offsets(document.text)
         worst_positive_segment: WorstSegment | None = None
         worst_negative_segment: WorstSegment | None = None
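+        # Build the sorted, validated boundary indexes once per document;
+        # the per-segment lookups below then reduce to a binary search.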
+        token_boundary_index_a = _token_boundary_index(tokenized_a) if tokenized_a is not None else None
+        token_boundary_index_b = _token_boundary_index(tokenized_b) if tokenized_b is not None else None
 
         for match in _SEGMENT_RE.finditer(document.text):
             segment = match.group(0)
@@ -217,7 +242,7 @@ def add_document(
                 loss_b=segment_loss_b,
                 num_bytes=segment_bytes,
             )
-            if tokenized_a is not None and tokenized_b is not None:
+            if token_boundary_index_a is not None and token_boundary_index_b is not None:
                 self._maybe_record_literal_example(
                     literal_key=literal_key,
                     document=document,
@@ -227,8 +252,8 @@ def add_document(
                     segment_char_start=match.start(),
                     segment_char_end=match.end(),
                     segment_delta_bits=segment_delta_bits,
-                    tokenized_a=tokenized_a,
-                    tokenized_b=tokenized_b,
+                    token_boundary_index_a=token_boundary_index_a,
+                    token_boundary_index_b=token_boundary_index_b,
                 )
 
             if segment_delta_bits == 0.0:
@@ -363,29 +388,32 @@ def _maybe_record_literal_example(
         segment_char_start: int,
         segment_char_end: int,
         segment_delta_bits: float,
-        tokenized_a: TokenizedDocument,
-        tokenized_b: TokenizedDocument,
+        token_boundary_index_a: _TokenBoundaryIndex,
+        token_boundary_index_b: _TokenBoundaryIndex,
     ) -> None:
-        candidate = LiteralExample(
-            abs_delta_bits=abs(segment_delta_bits),
+        abs_delta_bits = abs(segment_delta_bits)
+        current = self.literal_examples.get(literal_key)
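+        # Bail out early so the span extraction below only runs when this
+        # segment beats the best example recorded for the literal so far.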
+        if current is not None and abs_delta_bits <= current.abs_delta_bits:
+            return
+
+        self.literal_examples[literal_key] = LiteralExample(
+            abs_delta_bits=abs_delta_bits,
             dataset_name=document.dataset_name,
             doc_preview=preview_text_window(document.text, segment_char_start, segment_char_end),
-            model_a_token_boundaries=render_token_boundaries(
-                segment_text=segment_text,
+            segment_text=segment_text,
+            segment_byte_start=segment_byte_start,
+            segment_byte_end=segment_byte_end,
+            model_a_token_spans=_overlapping_token_spans(
+                token_boundary_index_a,
                 segment_byte_start=segment_byte_start,
                 segment_byte_end=segment_byte_end,
-                tokenized=tokenized_a,
             ),
-            model_b_token_boundaries=render_token_boundaries(
-                segment_text=segment_text,
+            model_b_token_spans=_overlapping_token_spans(
+                token_boundary_index_b,
                 segment_byte_start=segment_byte_start,
                 segment_byte_end=segment_byte_end,
-                tokenized=tokenized_b,
             ),
         )
-        current = self.literal_examples.get(literal_key)
-        if current is None or candidate.abs_delta_bits > current.abs_delta_bits:
-            self.literal_examples[literal_key] = candidate
 
 
 def iter_raw_text_documents(
@@ -655,6 +683,68 @@ def render_token_boundaries(
     return "|" + "|".join(pieces) + "|"
 
 
+def _token_boundary_index(tokenized: TokenizedDocument) -> _TokenBoundaryIndex:
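+    # Keep only real byte spans: drop sentinel entries (start < 0) and
+    # zero-width tokens. Assuming tokens appear in document order, the
+    # surviving byte_ends stay sorted, which searchsorted below relies on.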
+    valid = (tokenized.byte_starts >= 0) & (tokenized.byte_ends > tokenized.byte_starts)
+    return _TokenBoundaryIndex(
+        byte_starts=tokenized.byte_starts[valid].astype(np.int64, copy=False),
+        byte_ends=tokenized.byte_ends[valid].astype(np.int64, copy=False),
+    )
+
+
+def _overlapping_token_spans(
+    token_boundary_index: _TokenBoundaryIndex,
+    *,
+    segment_byte_start: int,
+    segment_byte_end: int,
+) -> _TokenBoundarySpans:
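+    # Binary-search to the first token whose end lies strictly past the
+    # segment start; tokens before that index cannot overlap the segment.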
+    first_index = int(np.searchsorted(token_boundary_index.byte_ends, segment_byte_start, side="right"))
+    byte_starts: list[int] = []
+    byte_ends: list[int] = []
+    for token_start, token_end in zip(
+        token_boundary_index.byte_starts[first_index:],
+        token_boundary_index.byte_ends[first_index:],
+        strict=True,
+    ):
+        if token_start >= segment_byte_end:
+            break
+        if token_end <= segment_byte_start:
+            continue
+        byte_starts.append(int(token_start))
+        byte_ends.append(int(token_end))
+
+    return _TokenBoundarySpans(byte_starts=tuple(byte_starts), byte_ends=tuple(byte_ends))
+
+
+def _tokenized_document_from_boundary_spans(token_spans: _TokenBoundarySpans) -> TokenizedDocument:
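+    # Wrap the stored spans back into a TokenizedDocument so the existing
+    # render_token_boundaries() can consume them; token_ids and num_bytes are
+    # placeholders, since only the byte offsets matter for rendering.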
+    byte_starts = np.asarray(token_spans.byte_starts, dtype=np.int32)
+    byte_ends = np.asarray(token_spans.byte_ends, dtype=np.int32)
+    return TokenizedDocument(
+        token_ids=np.zeros(len(byte_starts), dtype=np.int32),
+        byte_starts=byte_starts,
+        byte_ends=byte_ends,
+        num_bytes=0,
+    )
+
+
+def _render_literal_example(example: LiteralExample) -> _RenderedLiteralExample:
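+    # Turn a stored example into display strings on demand; rendering is
+    # deferred from add_document() to report time so only rows that are
+    # actually emitted pay for it.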
+    return _RenderedLiteralExample(
+        dataset_name=example.dataset_name,
+        doc_preview=example.doc_preview,
+        model_a_token_boundaries=render_token_boundaries(
+            segment_text=example.segment_text,
+            segment_byte_start=example.segment_byte_start,
+            segment_byte_end=example.segment_byte_end,
+            tokenized=_tokenized_document_from_boundary_spans(example.model_a_token_spans),
+        ),
+        model_b_token_boundaries=render_token_boundaries(
+            segment_text=example.segment_text,
+            segment_byte_start=example.segment_byte_start,
+            segment_byte_end=example.segment_byte_end,
+            tokenized=_tokenized_document_from_boundary_spans(example.model_b_token_spans),
+        ),
+    )
+
+
 def bucket_for_segment(segment: str) -> str:
     if segment.isspace():
         if "\t" in segment or "\r" in segment:
@@ -810,31 +900,41 @@ def _top_literal_rows(
     direction: str,
     limit: int,
 ) -> list[dict[str, Any]]:
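+    # Carry literal_key alongside each row so example lookup and rendering
+    # can be deferred until after sorting and the limit cut below.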
-    rows: list[dict[str, Any]] = []
-    for (bucket, literal), stats in literal_stats.items():
-        example = literal_examples.get((bucket, literal))
+    rows: list[tuple[tuple[str, str], dict[str, Any]]] = []
+    for literal_key, stats in literal_stats.items():
+        bucket, literal = literal_key
         stats_row = stats.as_dict(literal)
         row = {
             "name": stats_row["name"],
             "bucket": bucket,
-            "example_dataset": example.dataset_name if example is not None else None,
-            "example_doc_preview": example.doc_preview if example is not None else None,
-            "model_a_token_boundaries": example.model_a_token_boundaries if example is not None else None,
-            "model_b_token_boundaries": example.model_b_token_boundaries if example is not None else None,
             "documents": stats_row["documents"],
             "bytes": stats_row["bytes"],
             "model_a_bpb": stats_row["model_a_bpb"],
             "model_b_bpb": stats_row["model_b_bpb"],
             "gap_bpb": stats_row["gap_bpb"],
             "delta_bits": stats_row["delta_bits"],
         }
-        rows.append(row)
+        rows.append((literal_key, row))
 
     if direction == "positive":
-        rows = [row for row in rows if row["delta_bits"] > 0]
-        rows.sort(key=lambda row: row["delta_bits"], reverse=True)
+        rows = [(literal_key, row) for literal_key, row in rows if row["delta_bits"] > 0]
+        rows.sort(key=lambda item: item[1]["delta_bits"], reverse=True)
     else:
-        rows = [row for row in rows if row["delta_bits"] < 0]
-        rows.sort(key=lambda row: row["delta_bits"])
+        rows = [(literal_key, row) for literal_key, row in rows if row["delta_bits"] < 0]
+        rows.sort(key=lambda item: item[1]["delta_bits"])
+
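+    # Only the rows that survive the limit cut get the comparatively
+    # expensive token-boundary rendering.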
+    selected_rows: list[dict[str, Any]] = []
+    for literal_key, row in rows[:limit]:
+        example = literal_examples.get(literal_key)
+        rendered_example = _render_literal_example(example) if example is not None else None
+        row["example_dataset"] = rendered_example.dataset_name if rendered_example is not None else None
+        row["example_doc_preview"] = rendered_example.doc_preview if rendered_example is not None else None
+        row["model_a_token_boundaries"] = (
+            rendered_example.model_a_token_boundaries if rendered_example is not None else None
+        )
+        row["model_b_token_boundaries"] = (
+            rendered_example.model_b_token_boundaries if rendered_example is not None else None
+        )
+        selected_rows.append(row)
 
-    return rows[:limit]
+    return selected_rows