Skip to content

Commit 56bedef

Browse files
committed
feat: correct RAG CLI empty answers and improve UX
1 parent 6281137 commit 56bedef

5 files changed

Lines changed: 299 additions & 29 deletions

File tree

src/mmore/cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import os
2+
3+
# Hide gRPC's fork-handler messages ("FD from fork parent still in poll list") by default.
4+
os.environ.setdefault("GRPC_VERBOSITY", "ERROR")
5+
16
from typing import Optional
27

38
import click

src/mmore/rag/pipeline.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@
1111
from langchain_core.language_models.chat_models import BaseChatModel
1212
from langchain_core.output_parsers import StrOutputParser
1313
from langchain_core.prompts import ChatPromptTemplate
14-
from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough
14+
from langchain_core.runnables import (
15+
Runnable,
16+
RunnableConfig,
17+
RunnableLambda,
18+
RunnablePassthrough,
19+
)
1520

1621
from ..utils import load_config
1722
from .judge import JUDGE_OUTPUT_KEYS, JudgeConfig, LLMJudge, retrieve_with_judge
@@ -130,14 +135,17 @@ def retrieval_with_judge(state: Dict[str, Any]) -> Dict[str, Any]:
130135
return validate_input | core_chain | validate_output
131136

132137
def __call__(
133-
self, queries: Dict[str, Any] | List[Dict[str, Any]], return_dict: bool = False
134-
) -> List[Dict[str, str | List[str]]]:
138+
self,
139+
queries: Dict[str, Any] | List[Dict[str, Any]],
140+
return_dict: bool = False,
141+
config: Optional[RunnableConfig] = None,
142+
) -> List[Dict[str, Any]]:
135143
if isinstance(queries, Dict):
136144
queries_list = [queries]
137145
else:
138146
queries_list = queries
139147

140-
results = self.rag_chain.batch(queries_list)
148+
results = self.rag_chain.batch(queries_list, config=config)
141149

142150
if return_dict:
143151
return results

src/mmore/run_ragcli.py

Lines changed: 205 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
import argparse
2+
import itertools
23
import logging
3-
from typing import Optional
4+
import random
5+
import sys
6+
import threading
7+
import time
8+
import warnings
9+
from typing import Any, Dict, List, Optional
410

511
from huggingface_hub import model_info
612
from huggingface_hub.errors import HfHubHTTPError
13+
from langchain_core.callbacks import BaseCallbackHandler
14+
from pymilvus.exceptions import MilvusException
715

816
RAG_EMOJI = "🧠🧠🧠🧠🧠"
917
logger = logging.getLogger(__name__)
@@ -29,27 +37,22 @@ def __init__(self, config_file: str):
2937
self.config_file = config_file
3038

3139
def launch_cli(self):
40+
quiet_noisy_libs()
3241
print_in_color(
3342
"Welcome to this RAG command-line interface! 🧠", "green", bold=True
3443
)
3544
print(
36-
"Available commands are: config, rag, setK, setModel, setWebrag, exit, help. To learn more about usage of a specific command, use the following: \n help <command>"
45+
f"\nPress {str_green('Enter', bold=True)} to start asking questions about your documents.\n"
3746
)
3847
print(
39-
f"Available commands:\n\
48+
f"Other commands:\n\
4049
{str_green('config')} : see the current config \n\
41-
{str_green('rag')} : enter the RAG CLI \n\
4250
{str_green('setK')} : set the number of documents to retrieve \n\
4351
{str_green('setModel')} : set the model for generation \n\
4452
{str_green('setWebrag')} : decide whether to use web rag \n\
45-
{str_green('help')} : learn more about a command \n\
53+
{str_green('help')} : learn more about a command (help <command>) \n\
4654
{str_green('exit')} : exit the CLI"
4755
)
48-
print_in_color(
49-
"To learn more about usage of a specific command, use the following: \n help <command>",
50-
"blue",
51-
bold=True,
52-
)
5356
while True:
5457
try:
5558
cmd = input("> ").strip()
@@ -58,7 +61,7 @@ def launch_cli(self):
5861
break
5962
elif cmd == "help":
6063
print(
61-
"Available commands are: config, rag, setK, setModel, webrag, exit, help. To learn more about usage of a specific command, use the following: \n help <command>"
64+
f"Press {str_green('Enter')} (or type rag) to start asking questions about your documents.\nOther commands are: config, setK, setModel, setWebrag, exit, help. To learn more about usage of a specific command, use the following: \n help <command>"
6265
)
6366
elif cmd.startswith("help "):
6467
command = cmd.split(" ", 1)[1]
@@ -69,7 +72,9 @@ def launch_cli(self):
6972
elif command == "config":
7073
print("Print the current configuration.")
7174
elif command == "rag":
72-
print("Enter the RAG CLI. Type /bye to exit.")
75+
print(
76+
"Start a chat session to ask questions about your documents. Type /bye to exit."
77+
)
7378
elif command == "setK":
7479
print(
7580
"Use the command in the following way: 'setK <k>', for a positive integer k. This will set the number of documents to retrieve during RAG."
@@ -147,27 +152,42 @@ def launch_cli(self):
147152
f"Invalid output. Enter {str_in_color('setWebrag True', 'green')} or {str_in_color('setWebrag False', 'red')}."
148153
)
149154

150-
elif cmd == "rag":
155+
elif cmd in ("", "rag"):
151156
self.cli_ception()
152157

153158
else:
154159
print(f"Unknown command: {cmd}")
160+
if " " in cmd or cmd.endswith("?"):
161+
print(
162+
f"Looks like a question! Press {str_green('Enter')} first to start asking questions about your documents."
163+
)
155164
except (EOFError, KeyboardInterrupt):
156165
print("\nExiting...")
157166
break
158167

159168
def cli_ception(self):
169+
self.init_config()
170+
if self.ragPP is None or self.modified:
171+
try:
172+
with Spinner():
173+
self.initialize_ragpp()
174+
except MilvusException as e:
175+
print_in_color(f"Failed to open the document database: {e}", "red")
176+
print(
177+
f"A previous session may still be holding it. Run {str_green('pkill -f milvus_lite/lib/milvus')} and try again."
178+
)
179+
return
180+
self.modified = False
181+
print_in_color("RAG pipeline ready! Ask your questions.", "green")
160182
while True:
161183
query = input(str_in_color("rag (type /bye to exit) > ", "red", bold=True))
162184
if query == "/bye":
163185
print_in_color("Exiting the RAG CLI", "red", True)
164186
break
165187
else:
166-
self.init_config()
167-
if self.ragPP is None or self.modified:
168-
self.initialize_ragpp()
169-
self.modified = False
170-
self.do_rag(query)
188+
with Spinner():
189+
results, timings = self.do_rag(query)
190+
self.print_answer(results, timings)
171191

172192
def init_config(self):
173193
if self.ragConfig is None:
@@ -181,22 +201,181 @@ def initialize_ragpp(self):
181201
logger.info("RAG pipeline initialized!")
182202

183203
@profile_function()
184-
def do_rag(self, query):
204+
def do_rag(self, query) -> tuple[List[Dict[str, Any]], "TimingHandler"]:
185205
queries = [{"input": query, "collection_name": "my_docs"}]
186206
# called only after init_config and initialize_ragpp
187207
assert self.ragConfig is not None
188208
assert self.ragPP is not None
189209

190-
results = self.ragPP(queries, return_dict=True)
210+
timings = TimingHandler()
211+
results = self.ragPP(queries, return_dict=True, config={"callbacks": [timings]})
212+
return results, timings
213+
214+
def print_answer(
215+
self, results: List[Dict[str, Any]], timings: Optional["TimingHandler"] = None
216+
) -> None:
217+
assert self.ragConfig is not None
191218

192-
print(query)
193-
print(results[0]["answer"][-1].split("<|end_header_id|>")[-1])
219+
answer = results[0]["answer"].split("<|end_header_id|>")[-1].strip()
220+
print(f"\n{answer}\n")
221+
if timings is not None:
222+
self._print_metrics(results[0], answer, timings)
194223
if self.ragConfig.rag.retriever.use_web:
195-
print("\nSources: \n")
224+
print("Sources:")
196225
for i in range(self.ragConfig.rag.retriever.k):
197226
url = results[0]["docs"][i]["metadata"]["url"] # pyright: ignore
198227
title = results[0]["docs"][i]["metadata"]["title"] # pyright: ignore
199-
print(f"{title} : {url}")
228+
print(f" - {title}: {url}")
229+
print()
230+
231+
def _print_metrics(
    self, result: Dict[str, Any], answer: str, timings: "TimingHandler"
) -> None:
    """Print a two-line gray summary of the last RAG call.

    The first line names the model (marked local or API) followed by any
    retrieval/generation wall times recorded in ``timings``. The second line
    reports the number of retrieved chunks, context token count, generated
    token count (with throughput when a generation time exists) and the best
    similarity score found in the chunk metadata.

    Args:
        result: Single result dict; the ``docs`` and ``context`` keys are read.
        answer: Answer text, used to count tokens when ``timings`` has none.
        timings: Handler holding the measured times and completion tokens.
    """
    assert self.ragConfig is not None
    llm_cfg = self.ragConfig.rag.llm

    # --- line 1: model and timings -------------------------------------
    model_line = [f"{llm_cfg.llm_name} ({'local' if llm_cfg.provider == 'HF' else 'API'})"]
    if timings.retrieval_time is not None:
        model_line.append(f"retrieval {timings.retrieval_time:.2f}s")
    if timings.generation_time is not None:
        model_line.append(f"generation {timings.generation_time:.2f}s")

    # --- line 2: chunk and token statistics ----------------------------
    chunks = result.get("docs") or []
    stats_line = [f"{len(chunks)} chunks"]

    context_tokens = self._count_tokens(result.get("context"))
    if context_tokens is not None:
        if context_tokens >= 1000:
            shown = f"{context_tokens / 1000:.1f}k"
        else:
            shown = str(context_tokens)
        stats_line.append(f"{shown} context tokens")

    # Prefer the provider-reported count; fall back to the local tokenizer.
    generated = timings.completion_tokens or self._count_tokens(answer)
    if generated:
        token_part = f"{generated} tokens"
        if timings.generation_time:
            token_part += f" @ {generated / timings.generation_time:.0f} tok/s"
        stats_line.append(token_part)

    similarities = []
    for chunk in chunks:
        if chunk.get("metadata", {}).get("similarity") is not None:
            similarities.append(chunk["metadata"]["similarity"])
    if similarities:
        stats_line.append(f"top score {max(similarities):.2f}")

    print(str_in_color(" | ".join(model_line), "gray"))
    print(str_in_color(" | ".join(stats_line), "gray") + "\n")
268+
269+
def _count_tokens(self, text: Optional[str]) -> Optional[int]:
    """Count the tokens of ``text`` with the pipeline LLM's tokenizer.

    Returns:
        The length of ``tokenizer.encode(text)``, or ``None`` when ``text`` is
        empty, no pipeline is loaded, the LLM exposes no ``tokenizer``
        attribute, or encoding raises.
    """
    tok = None
    if text and self.ragPP is not None:
        tok = getattr(self.ragPP.llm, "tokenizer", None)
    if tok is not None:
        try:
            return len(tok.encode(text))
        except Exception:
            # Best effort only: a failing tokenizer just means "no count".
            return None
    return None
280+
281+
282+
class TimingHandler(BaseCallbackHandler):
    """Collects retrieval/generation wall times and token usage from callbacks.

    ``retrieval_time``, ``generation_time`` and ``completion_tokens`` stay
    ``None`` until the matching start/end callback pair has fired. Start
    timestamps are tracked per ``run_id`` and dropped again when the run ends
    or fails, so a failed run never leaves a stale entry behind.
    """

    def __init__(self):
        super().__init__()
        self.retrieval_time: Optional[float] = None
        self.generation_time: Optional[float] = None
        self.completion_tokens: Optional[int] = None
        # time.perf_counter() value recorded at the start of each open run.
        self._starts: Dict[Any, float] = {}

    def on_retriever_start(self, serialized, query, *, run_id, **kwargs):
        self._starts[run_id] = time.perf_counter()

    def on_retriever_end(self, documents, *, run_id, **kwargs):
        started = self._starts.pop(run_id, None)
        if started is not None:
            self.retrieval_time = time.perf_counter() - started

    def on_retriever_error(self, error, *, run_id, **kwargs):
        # The run will never reach on_retriever_end; forget its start time.
        self._starts.pop(run_id, None)

    def on_llm_start(self, serialized, prompts, *, run_id, **kwargs):
        self._starts[run_id] = time.perf_counter()

    def on_chat_model_start(self, serialized, messages, *, run_id, **kwargs):
        self._starts[run_id] = time.perf_counter()

    def on_llm_end(self, response, *, run_id, **kwargs):
        started = self._starts.pop(run_id, None)
        if started is not None:
            self.generation_time = time.perf_counter() - started
            self.completion_tokens = _output_tokens(response)

    def on_llm_error(self, error, *, run_id, **kwargs):
        # The run will never reach on_llm_end; forget its start time.
        self._starts.pop(run_id, None)
308+
309+
310+
def _output_tokens(response) -> Optional[int]:
    """Generated-token count if the provider reported it (API models do; HF rarely).

    Looks first at ``usage_metadata`` on the first generation's message, then
    falls back to the ``token_usage`` entry of ``response.llm_output``.

    Args:
        response: The LLM result object handed to ``on_llm_end``.

    Returns:
        The reported number of generated tokens, or ``None`` when the
        response carries no usable usage information.
    """
    try:
        usage = response.generations[0][0].message.usage_metadata
        if usage and usage.get("output_tokens"):
            return usage["output_tokens"]
    except (AttributeError, IndexError, KeyError, TypeError):
        # No generations, or a generation without message/usage metadata.
        pass
    # ``llm_output`` may be missing or None, and ``token_usage`` may be present
    # but explicitly None; treat all of those as "nothing reported".
    llm_output = getattr(response, "llm_output", None) or {}
    token_usage = llm_output.get("token_usage") or {}
    return token_usage.get("completion_tokens") or token_usage.get("output_tokens")
320+
321+
322+
# Words shown next to the spinner; a new one is picked every few seconds.
SPINNER_WORDS = [
    "Thinking",
    "Pondering",
    "Discombobulating",
    "Cooking",
    "Brewing",
    "Ruminating",
    "Rummaging",
    "Noodling",
]


class Spinner:
    """Animated status line shown while work happens in the calling thread.

    Use as a context manager. The animation runs on a daemon thread and is
    only started when stdout is a TTY; otherwise entering and leaving the
    context does nothing visible. An instance may be entered more than once.
    """

    def __init__(self):
        self._stop = threading.Event()
        self._thread: Optional[threading.Thread] = None

    def __enter__(self):
        # Clear a stop flag left over from a previous use of this instance.
        self._stop.clear()
        if sys.stdout.isatty():
            self._thread = threading.Thread(target=self._spin, daemon=True)
            self._thread.start()
        return self

    def __exit__(self, *exc):
        self._stop.set()
        if self._thread is not None:
            self._thread.join()
            self._thread = None
            # Erase the status line so the next output starts on a clean row.
            sys.stdout.write("\r\033[K")
            sys.stdout.flush()

    def _spin(self):
        """Redraw the status line roughly every 100 ms until stopped."""
        frames = itertools.cycle("|/-\\")
        word = random.choice(SPINNER_WORDS)
        start = word_start = time.monotonic()
        while not self._stop.is_set():
            now = time.monotonic()
            if now - word_start > 3:
                word = random.choice(SPINNER_WORDS)
                word_start = now
            status = f"{next(frames)} {word}... ({int(now - start)}s)"
            sys.stdout.write(f"\r\033[K{str_in_color(status, 'blue')}")
            sys.stdout.flush()
            # Wait on the event rather than sleeping so that __exit__ is not
            # held up for the remainder of the frame interval.
            self._stop.wait(0.1)
367+
368+
369+
def quiet_noisy_libs():
    """Silence INFO-level logging, Python warnings and transformers output.

    Disables all logging at INFO and below, ignores every warning, and, when
    ``transformers`` can be imported, sets its verbosity to errors only and
    turns off its progress bars.
    """
    warnings.filterwarnings("ignore")
    logging.disable(logging.INFO)
    try:
        from transformers.utils import logging as hf_logging
    except ImportError:
        # transformers is optional here; nothing more to quiet.
        pass
    else:
        hf_logging.set_verbosity_error()
        hf_logging.disable_progress_bar()
200379

201380

202381
def is_valid_model_path(model_path: str):
@@ -218,6 +397,7 @@ def str_in_color(to_print: str | int, color: str, bold: bool = False) -> str:
218397
"green": "\033[32m",
219398
"yellow": "\033[33m",
220399
"blue": "\033[34m",
400+
"gray": "\033[90m",
221401
}
222402
style = colors.get(color, colors["reset"])
223403
if bold:
@@ -234,6 +414,7 @@ def str_green(text, bold=False):
234414

235415

236416
if __name__ == "__main__":
417+
quiet_noisy_libs()
237418
enable_profiling_from_env()
238419
# example usage: python -m mmore.ragcli --config-file examples/rag/config.yaml
239420

src/mmore/tui/app.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def _run_single_command() -> None:
145145
if spec.needs_input_data:
146146
input_data = questionary.text(
147147
"Input JSONL path",
148-
default=cwd_default("outputs/process/merged/merged_results.jsonl"),
148+
default=cwd_default("examples/process/outputs/merged/merged_results.jsonl"),
149149
style=QSTYLE,
150150
qmark=QMARK,
151151
).ask()

0 commit comments

Comments
 (0)