11import argparse
2- import itertools
32import logging
4- import random
5- import sys
6- import threading
7- import time
8- import warnings
93from typing import Any , Dict , List , Optional
104
115from huggingface_hub import model_info
126from huggingface_hub .errors import HfHubHTTPError
13- from langchain_core .callbacks import BaseCallbackHandler
147from pymilvus .exceptions import MilvusException
158
169RAG_EMOJI = "🧠🧠🧠🧠🧠"
2316
2417from mmore .profiler import enable_profiling_from_env , profile_function
2518from mmore .rag .pipeline import RAGPipeline
19+ from mmore .ragcli_console import (
20+ Spinner ,
21+ TimingHandler ,
22+ print_in_color ,
23+ quiet_noisy_libs ,
24+ str_green ,
25+ str_in_color ,
26+ )
2627from mmore .run_rag import RAGInferenceConfig
2728from mmore .utils import load_config
2829
@@ -279,105 +280,6 @@ def _count_tokens(self, text: Optional[str]) -> Optional[int]:
279280 return None
280281
281282
282- class TimingHandler (BaseCallbackHandler ):
283- """Collects retrieval/generation wall times and token usage from callbacks."""
284-
285- def __init__ (self ):
286- self .retrieval_time : Optional [float ] = None
287- self .generation_time : Optional [float ] = None
288- self .completion_tokens : Optional [int ] = None
289- self ._starts : Dict [Any , float ] = {}
290-
291- def on_retriever_start (self , serialized , query , * , run_id , ** kwargs ):
292- self ._starts [run_id ] = time .perf_counter ()
293-
294- def on_retriever_end (self , documents , * , run_id , ** kwargs ):
295- if run_id in self ._starts :
296- self .retrieval_time = time .perf_counter () - self ._starts .pop (run_id )
297-
298- def on_llm_start (self , serialized , prompts , * , run_id , ** kwargs ):
299- self ._starts [run_id ] = time .perf_counter ()
300-
301- def on_chat_model_start (self , serialized , messages , * , run_id , ** kwargs ):
302- self ._starts [run_id ] = time .perf_counter ()
303-
304- def on_llm_end (self , response , * , run_id , ** kwargs ):
305- if run_id in self ._starts :
306- self .generation_time = time .perf_counter () - self ._starts .pop (run_id )
307- self .completion_tokens = _output_tokens (response )
308-
309-
310- def _output_tokens (response ) -> Optional [int ]:
311- """Generated-token count if the provider reported it (API models do; HF rarely)."""
312- try :
313- usage = response .generations [0 ][0 ].message .usage_metadata
314- if usage and usage .get ("output_tokens" ):
315- return usage ["output_tokens" ]
316- except (AttributeError , IndexError , TypeError ):
317- pass
318- usage = (response .llm_output or {}).get ("token_usage" , {})
319- return usage .get ("completion_tokens" ) or usage .get ("output_tokens" )
320-
321-
322- SPINNER_WORDS = [
323- "Thinking" ,
324- "Pondering" ,
325- "Discombobulating" ,
326- "Cooking" ,
327- "Brewing" ,
328- "Ruminating" ,
329- "Rummaging" ,
330- "Noodling" ,
331- ]
332-
333-
334- class Spinner :
335- """Animated status line shown while work happens in the calling thread."""
336-
337- def __init__ (self ):
338- self ._stop = threading .Event ()
339- self ._thread : Optional [threading .Thread ] = None
340-
341- def __enter__ (self ):
342- if sys .stdout .isatty ():
343- self ._thread = threading .Thread (target = self ._spin , daemon = True )
344- self ._thread .start ()
345- return self
346-
347- def __exit__ (self , * exc ):
348- self ._stop .set ()
349- if self ._thread is not None :
350- self ._thread .join ()
351- sys .stdout .write ("\r \033 [K" )
352- sys .stdout .flush ()
353-
354- def _spin (self ):
355- frames = itertools .cycle ("|/-\\ " )
356- word = random .choice (SPINNER_WORDS )
357- start = word_start = time .monotonic ()
358- while not self ._stop .is_set ():
359- now = time .monotonic ()
360- if now - word_start > 3 :
361- word = random .choice (SPINNER_WORDS )
362- word_start = now
363- status = f"{ next (frames )} { word } ... ({ int (now - start )} s)"
364- sys .stdout .write (f"\r \033 [K{ str_in_color (status , 'blue' )} " )
365- sys .stdout .flush ()
366- time .sleep (0.1 )
367-
368-
369- def quiet_noisy_libs ():
370- """Hide INFO logs, warnings and progress bars so the CLI stays clean."""
371- logging .disable (logging .INFO )
372- warnings .filterwarnings ("ignore" )
373- try :
374- from transformers .utils import logging as hf_logging
375- except ImportError :
376- return
377- hf_logging .set_verbosity_error ()
378- hf_logging .disable_progress_bar ()
379-
380-
381283def is_valid_model_path (model_path : str ):
382284 try :
383285 model_info (model_path )
@@ -389,30 +291,6 @@ def is_valid_model_path(model_path: str):
389291 )
390292
391293
392- def str_in_color (to_print : str | int , color : str , bold : bool = False ) -> str :
393- colors = {
394- "reset" : "\033 [0m" ,
395- "bold" : "\033 [1m" ,
396- "red" : "\033 [31m" ,
397- "green" : "\033 [32m" ,
398- "yellow" : "\033 [33m" ,
399- "blue" : "\033 [34m" ,
400- "gray" : "\033 [90m" ,
401- }
402- style = colors .get (color , colors ["reset" ])
403- if bold :
404- style = colors ["bold" ] + style
405- return f"{ style } { to_print } { colors ['reset' ]} "
406-
407-
408- def print_in_color (to_print : str | int , color : str , bold : bool = False ) -> None :
409- print (str_in_color (to_print , color , bold ))
410-
411-
412- def str_green (text , bold = False ):
413- return str_in_color (text , "green" , bold = bold )
414-
415-
416294if __name__ == "__main__" :
417295 quiet_noisy_libs ()
418296 enable_profiling_from_env ()
0 commit comments