Skip to content

Commit 68dd154

Browse files
committed
refactor: create separate file for rag cli display helpers
1 parent 56bedef commit 68dd154

2 files changed

Lines changed: 154 additions & 130 deletions

File tree

src/mmore/ragcli_console.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
"""Presentation helpers for the RAG CLI: colored output, spinner, noise control,
2+
and timing/token metrics collection."""
3+
4+
import itertools
5+
import logging
6+
import random
7+
import sys
8+
import threading
9+
import time
10+
import warnings
11+
from typing import Any, Dict, Optional
12+
13+
from langchain_core.callbacks import BaseCallbackHandler
14+
15+
# ----------------------------- Colored output ------------------------------ #
16+
17+
18+
def str_in_color(to_print: str | int, color: str, bold: bool = False) -> str:
19+
colors = {
20+
"reset": "\033[0m",
21+
"bold": "\033[1m",
22+
"red": "\033[31m",
23+
"green": "\033[32m",
24+
"yellow": "\033[33m",
25+
"blue": "\033[34m",
26+
"gray": "\033[90m",
27+
}
28+
style = colors.get(color, colors["reset"])
29+
if bold:
30+
style = colors["bold"] + style
31+
return f"{style}{to_print}{colors['reset']}"
32+
33+
34+
def print_in_color(to_print: str | int, color: str, bold: bool = False) -> None:
    """Print ``to_print`` to stdout, styled by ``str_in_color``."""
    rendered = str_in_color(to_print, color, bold)
    print(rendered)
36+
37+
38+
def str_green(text, bold=False):
    """Shorthand for ``str_in_color`` with the color fixed to ``"green"``."""
    green_text = str_in_color(text, "green", bold=bold)
    return green_text
40+
41+
42+
# ------------------------------ Noise control ------------------------------ #
43+
44+
45+
def quiet_noisy_libs():
    """Hide INFO logs, warnings and progress bars so the CLI stays clean.

    Disables logging at INFO and below, installs a blanket "ignore" warnings
    filter, and, when ``transformers`` can be imported, lowers its verbosity
    to errors and turns off its progress bars.
    """
    logging.disable(logging.INFO)
    warnings.filterwarnings("ignore")
    try:
        from transformers.utils import logging as hf_logging
    except ImportError:
        # transformers is optional here; nothing more to silence without it.
        pass
    else:
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()
55+
56+
57+
# --------------------------------- Spinner --------------------------------- #
58+
59+
# Status words for the spinner; one is drawn at random each time the word
# is (re)chosen.
SPINNER_WORDS = [
    "Thinking", "Pondering", "Discombobulating", "Cooking",
    "Brewing", "Ruminating", "Rummaging", "Noodling",
]
69+
70+
71+
class Spinner:
    """Animated status line shown while work happens in the calling thread.

    Use as a context manager. The animation runs on a daemon thread and is
    only started when stdout is a terminal; when it is not, entering and
    leaving the context writes nothing. An instance may be entered more than
    once (sequentially).
    """

    def __init__(self):
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def __enter__(self):
        # Re-arm the stop flag: without this a second ``with`` on the same
        # instance would start a thread that exits immediately.
        self._stop.clear()
        if sys.stdout.isatty():
            self._thread = threading.Thread(target=self._spin, daemon=True)
            self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        if self._thread is None:
            # No animation was drawn, so there is no line to erase; emitting
            # escape codes here would pollute piped/redirected output.
            return
        self._thread.join()
        self._thread = None
        # Erase the status line left behind by the last frame.
        sys.stdout.write("\r\033[K")
        sys.stdout.flush()

    def _spin(self):
        """Redraw the status line roughly every 0.1s until asked to stop."""
        frames = itertools.cycle("|/-\\")
        word = random.choice(SPINNER_WORDS)
        start = word_start = time.monotonic()
        while not self._stop.is_set():
            now = time.monotonic()
            # Pick a fresh word once the current one has been up for > 3s.
            if now - word_start > 3:
                word = random.choice(SPINNER_WORDS)
                word_start = now
            status = f"{next(frames)} {word}... ({int(now - start)}s)"
            sys.stdout.write(f"\r\033[K{str_in_color(status, 'blue')}")
            sys.stdout.flush()
            # Wait on the event instead of sleeping so that __exit__ is not
            # held up for the remainder of a tick.
            self._stop.wait(0.1)
104+
105+
106+
# ----------------------------- Timing metrics ------------------------------ #
107+
108+
109+
class TimingHandler(BaseCallbackHandler):
    """Collects retrieval/generation wall times and token usage from callbacks.

    Start callbacks record a ``time.perf_counter()`` timestamp keyed by
    ``run_id``; the matching end callback turns it into an elapsed time.
    End callbacks for a ``run_id`` with no recorded start are ignored, and
    error callbacks discard the pending timestamp so it does not linger.

    Attributes set by the callbacks (``None`` until the first matching run):
        retrieval_time: Seconds taken by the most recent retriever run.
        generation_time: Seconds taken by the most recent LLM/chat-model run.
        completion_tokens: Result of ``_output_tokens`` for that LLM run.
    """

    def __init__(self):
        self.retrieval_time: Optional[float] = None
        self.generation_time: Optional[float] = None
        self.completion_tokens: Optional[int] = None
        # run_id -> perf_counter() timestamp of runs that have not ended yet.
        self._starts: Dict[Any, float] = {}

    def _begin(self, run_id) -> None:
        """Remember when the run identified by ``run_id`` started."""
        self._starts[run_id] = time.perf_counter()

    def _elapsed(self, run_id) -> Optional[float]:
        """Pop the start of ``run_id`` and return seconds since, or None."""
        started = self._starts.pop(run_id, None)
        if started is None:
            return None
        return time.perf_counter() - started

    def on_retriever_start(self, serialized, query, *, run_id, **kwargs):
        self._begin(run_id)

    def on_retriever_end(self, documents, *, run_id, **kwargs):
        elapsed = self._elapsed(run_id)
        if elapsed is not None:
            self.retrieval_time = elapsed

    def on_retriever_error(self, error, *, run_id, **kwargs):
        # A failed run never reaches on_retriever_end; drop its timestamp.
        self._starts.pop(run_id, None)

    def on_llm_start(self, serialized, prompts, *, run_id, **kwargs):
        self._begin(run_id)

    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
        self._begin(run_id)

    def on_llm_end(self, response, *, run_id, **kwargs):
        elapsed = self._elapsed(run_id)
        if elapsed is not None:
            # Keep time and token count paired to the same tracked run.
            self.generation_time = elapsed
            self.completion_tokens = _output_tokens(response)

    def on_llm_error(self, error, *, run_id, **kwargs):
        # A failed run never reaches on_llm_end; drop its timestamp.
        self._starts.pop(run_id, None)
135+
136+
137+
def _output_tokens(response) -> Optional[int]:
138+
"""Generated-token count if the provider reported it (API models do; HF rarely)."""
139+
try:
140+
usage = response.generations[0][0].message.usage_metadata
141+
if usage and usage.get("output_tokens"):
142+
return usage["output_tokens"]
143+
except (AttributeError, IndexError, TypeError):
144+
pass
145+
usage = (response.llm_output or {}).get("token_usage", {})
146+
return usage.get("completion_tokens") or usage.get("output_tokens")

src/mmore/run_ragcli.py

Lines changed: 8 additions & 130 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
11
import argparse
2-
import itertools
32
import logging
4-
import random
5-
import sys
6-
import threading
7-
import time
8-
import warnings
93
from typing import Any, Dict, List, Optional
104

115
from huggingface_hub import model_info
126
from huggingface_hub.errors import HfHubHTTPError
13-
from langchain_core.callbacks import BaseCallbackHandler
147
from pymilvus.exceptions import MilvusException
158

169
RAG_EMOJI = "🧠🧠🧠🧠🧠"
@@ -23,6 +16,14 @@
2316

2417
from mmore.profiler import enable_profiling_from_env, profile_function
2518
from mmore.rag.pipeline import RAGPipeline
19+
from mmore.ragcli_console import (
20+
Spinner,
21+
TimingHandler,
22+
print_in_color,
23+
quiet_noisy_libs,
24+
str_green,
25+
str_in_color,
26+
)
2627
from mmore.run_rag import RAGInferenceConfig
2728
from mmore.utils import load_config
2829

@@ -279,105 +280,6 @@ def _count_tokens(self, text: Optional[str]) -> Optional[int]:
279280
return None
280281

281282

282-
class TimingHandler(BaseCallbackHandler):
283-
"""Collects retrieval/generation wall times and token usage from callbacks."""
284-
285-
def __init__(self):
286-
self.retrieval_time: Optional[float] = None
287-
self.generation_time: Optional[float] = None
288-
self.completion_tokens: Optional[int] = None
289-
self._starts: Dict[Any, float] = {}
290-
291-
def on_retriever_start(self, serialized, query, *, run_id, **kwargs):
292-
self._starts[run_id] = time.perf_counter()
293-
294-
def on_retriever_end(self, documents, *, run_id, **kwargs):
295-
if run_id in self._starts:
296-
self.retrieval_time = time.perf_counter() - self._starts.pop(run_id)
297-
298-
def on_llm_start(self, serialized, prompts, *, run_id, **kwargs):
299-
self._starts[run_id] = time.perf_counter()
300-
301-
def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
302-
self._starts[run_id] = time.perf_counter()
303-
304-
def on_llm_end(self, response, *, run_id, **kwargs):
305-
if run_id in self._starts:
306-
self.generation_time = time.perf_counter() - self._starts.pop(run_id)
307-
self.completion_tokens = _output_tokens(response)
308-
309-
310-
def _output_tokens(response) -> Optional[int]:
311-
"""Generated-token count if the provider reported it (API models do; HF rarely)."""
312-
try:
313-
usage = response.generations[0][0].message.usage_metadata
314-
if usage and usage.get("output_tokens"):
315-
return usage["output_tokens"]
316-
except (AttributeError, IndexError, TypeError):
317-
pass
318-
usage = (response.llm_output or {}).get("token_usage", {})
319-
return usage.get("completion_tokens") or usage.get("output_tokens")
320-
321-
322-
SPINNER_WORDS = [
323-
"Thinking",
324-
"Pondering",
325-
"Discombobulating",
326-
"Cooking",
327-
"Brewing",
328-
"Ruminating",
329-
"Rummaging",
330-
"Noodling",
331-
]
332-
333-
334-
class Spinner:
335-
"""Animated status line shown while work happens in the calling thread."""
336-
337-
def __init__(self):
338-
self._stop = threading.Event()
339-
self._thread: Optional[threading.Thread] = None
340-
341-
def __enter__(self):
342-
if sys.stdout.isatty():
343-
self._thread = threading.Thread(target=self._spin, daemon=True)
344-
self._thread.start()
345-
return self
346-
347-
def __exit__(self, *exc):
348-
self._stop.set()
349-
if self._thread is not None:
350-
self._thread.join()
351-
sys.stdout.write("\r\033[K")
352-
sys.stdout.flush()
353-
354-
def _spin(self):
355-
frames = itertools.cycle("|/-\\")
356-
word = random.choice(SPINNER_WORDS)
357-
start = word_start = time.monotonic()
358-
while not self._stop.is_set():
359-
now = time.monotonic()
360-
if now - word_start > 3:
361-
word = random.choice(SPINNER_WORDS)
362-
word_start = now
363-
status = f"{next(frames)} {word}... ({int(now - start)}s)"
364-
sys.stdout.write(f"\r\033[K{str_in_color(status, 'blue')}")
365-
sys.stdout.flush()
366-
time.sleep(0.1)
367-
368-
369-
def quiet_noisy_libs():
370-
"""Hide INFO logs, warnings and progress bars so the CLI stays clean."""
371-
logging.disable(logging.INFO)
372-
warnings.filterwarnings("ignore")
373-
try:
374-
from transformers.utils import logging as hf_logging
375-
except ImportError:
376-
return
377-
hf_logging.set_verbosity_error()
378-
hf_logging.disable_progress_bar()
379-
380-
381283
def is_valid_model_path(model_path: str):
382284
try:
383285
model_info(model_path)
@@ -389,30 +291,6 @@ def is_valid_model_path(model_path: str):
389291
)
390292

391293

392-
def str_in_color(to_print: str | int, color: str, bold: bool = False) -> str:
393-
colors = {
394-
"reset": "\033[0m",
395-
"bold": "\033[1m",
396-
"red": "\033[31m",
397-
"green": "\033[32m",
398-
"yellow": "\033[33m",
399-
"blue": "\033[34m",
400-
"gray": "\033[90m",
401-
}
402-
style = colors.get(color, colors["reset"])
403-
if bold:
404-
style = colors["bold"] + style
405-
return f"{style}{to_print}{colors['reset']}"
406-
407-
408-
def print_in_color(to_print: str | int, color: str, bold: bool = False) -> None:
409-
print(str_in_color(to_print, color, bold))
410-
411-
412-
def str_green(text, bold=False):
413-
return str_in_color(text, "green", bold=bold)
414-
415-
416294
if __name__ == "__main__":
417295
quiet_noisy_libs()
418296
enable_profiling_from_env()

0 commit comments

Comments
 (0)