forked from NVIDIA-AI-Blueprints/rag
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathdemo_model_extractor.py
More file actions
90 lines (71 loc) · 2.59 KB
/
demo_model_extractor.py
File metadata and controls
90 lines (71 loc) · 2.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
"""
Interactive demo for ModelNameExtractor with sentence transformers.
"""
import sys
import logging
# Add the src directory to the path so apply_configuration is importable
# when this demo is run from the repository root.
sys.path.insert(0, 'src')
from apply_configuration import ModelNameExtractor, MODEL_TAGS
# Configure logging for the demo session (timestamped, INFO and above).
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s'
)
# Module-level logger; not used by main() below, but available to
# library code that logs under this module's name.
logger = logging.getLogger(__name__)
def main():
    """Run the interactive demo loop.

    Prompts for free-form queries, asks ModelNameExtractor for the
    best-matching model tag, and prints the top three match scores.
    Exits cleanly on 'quit'/'exit'/'q', Ctrl-C, or end-of-input (EOF).
    """
    print("=" * 80)
    print("ModelNameExtractor Interactive Demo")
    print("=" * 80)
    print()

    # Initialize the extractor
    print("Initializing ModelNameExtractor...")
    extractor = ModelNameExtractor(MODEL_TAGS, similarity_threshold=0.3)
    print("✓ Using enhanced keyword matching")
    # NOTE(review): reaches into the private _tag_keywords attribute —
    # confirm ModelNameExtractor exposes a public way to get this count.
    print(f" Configured {len(extractor._tag_keywords)} models")
    print()

    print("Available models:")
    for i, tag in enumerate(MODEL_TAGS, 1):
        print(f" {i}. {tag}")
    print()

    print("Enter queries to find matching models (type 'quit' to exit)")
    print("Examples:")
    print(" - 'I need llama 3 with 8 billion parameters'")
    print(" - 'Deploy mistral for chat'")
    print(" - 'What's the best 70B model?'")
    print()

    while True:
        try:
            query = input("\nQuery: ").strip()
            if query.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break
            if not query:
                continue
            # Extract model
            result = extractor.extract(query)
            if result:
                print(f"✓ Matched model: {result}")
                # Show match scores for all models
                # NOTE(review): _calculate_match_score is private API;
                # verify it is intended for external use in this demo.
                scores = []
                for tag in extractor.tags:
                    score = extractor._calculate_match_score(query, tag)
                    if score > 0:
                        scores.append((tag, score))
                # Sort by score, best match first
                scores.sort(key=lambda x: x[1], reverse=True)
                if scores:
                    print("\n Top match scores:")
                    for tag, score in scores[:3]:
                        marker = "→" if tag == result else " "
                        print(f" {marker} {score:.3f} - {tag}")
            else:
                print("✗ No matching model found")
        except (KeyboardInterrupt, EOFError):
            # EOFError must break too: input() raises it on every call once
            # stdin is exhausted, so catching it in the generic handler
            # below would spin forever printing "Error: EOF ...".
            print("\nGoodbye!")
            break
        except Exception as e:
            # Best-effort demo loop: report the error and keep prompting.
            print(f"Error: {e}")
# Script entry point: run the interactive demo only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()