-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathmanage.py
More file actions
105 lines (91 loc) · 3.66 KB
/
Copy pathmanage.py
File metadata and controls
105 lines (91 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import argparse
import uvicorn
import sys
import os
# Training entry points are provided by the training_pipeline package.
from training_pipeline import train_teacher
from training_pipeline import distill_student
def run_server(args):
    """Launch the FastAPI application under uvicorn.

    Args:
        args: Parsed CLI namespace; its ``host``, ``port`` and ``reload``
            attributes are forwarded to ``uvicorn.run``.
    """
    print(f"🚀 Starting NeuroRank API on {args.host}:{args.port}...")
    # The app is referenced by import string ("module:attribute") rather
    # than by object.
    app_path = "ranker_service.main:app"
    server_options = {
        "host": args.host,
        "port": args.port,
        "reload": args.reload,
    }
    uvicorn.run(app_path, **server_options)
def run_teacher(args):
    """Announce and start the teacher training pipeline.

    Args:
        args: Parsed CLI namespace, passed through unchanged to
            ``train_teacher.run_training``.
    """
    banner = "👨🏫 Starting Teacher (BERT) training..."
    print(banner)
    start_training = train_teacher.run_training
    start_training(args)
def run_student(args):
    """Announce and start the student distillation pipeline.

    Args:
        args: Parsed CLI namespace, passed through unchanged to
            ``distill_student.run_distillation``.
    """
    message = "🎓 Starting Student (MiniLM) distillation..."
    print(message)
    launch_distillation = distill_student.run_distillation
    launch_distillation(args)
def _add_server_command(subparsers):
    """Register the ``runserver`` sub-command and its options."""
    server_parser = subparsers.add_parser("runserver", help="Start the API server")
    server_parser.add_argument("--host", default="0.0.0.0", help="Host to bind to")
    server_parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
    server_parser.add_argument(
        "--reload", action="store_true", help="Enable auto-reload for dev"
    )
    server_parser.set_defaults(func=run_server)


def _add_teacher_command(subparsers):
    """Register the ``train-teacher`` sub-command and its options."""
    teacher_parser = subparsers.add_parser(
        "train-teacher", help="Train the teacher model"
    )
    # Path defaults are relative to the current working directory.
    teacher_parser.add_argument(
        "--data_dir",
        default="./data",
        help="Path to the MS MARCO dataset",
    )
    teacher_parser.add_argument("--model", default="microsoft/MiniLM-L12-H384-uncased")
    teacher_parser.add_argument("--epochs", type=int, default=1)
    teacher_parser.add_argument("--lr", type=float, default=2e-5)
    teacher_parser.add_argument("--batch", type=int, default=16)
    teacher_parser.add_argument("--max_len", type=int, default=256)
    teacher_parser.add_argument(
        "--out_dir",
        default="./models/teacher",
        help="Output directory for teacher model",
    )
    teacher_parser.set_defaults(func=run_teacher)


def _add_student_command(subparsers):
    """Register the ``train-student`` sub-command and its options."""
    student_parser = subparsers.add_parser(
        "train-student", help="Distill into student model"
    )
    # Path defaults are relative to the current working directory.
    student_parser.add_argument(
        "--data_dir",
        default="./data",
        help="Path to the MS MARCO dataset",
    )
    # The default checkpoint sits under the train-teacher default --out_dir.
    student_parser.add_argument(
        "--teacher",
        default="./models/teacher/best.pt",
        help="Path to teacher best.pt",
    )
    student_parser.add_argument(
        "--student", default="sentence-transformers/all-MiniLM-L6-v2"
    )
    student_parser.add_argument("--epochs", type=int, default=1)
    student_parser.add_argument("--lr", type=float, default=3e-5)
    student_parser.add_argument("--batch", type=int, default=64)
    student_parser.add_argument("--max_len", type=int, default=256)
    student_parser.add_argument("--temp", type=float, default=3.0)
    student_parser.add_argument(
        "--out_dir",
        default="./models/student",
        help="Output directory for student model",
    )
    student_parser.set_defaults(func=run_student)


def _build_parser():
    """Create the top-level parser with every sub-command registered."""
    parser = argparse.ArgumentParser(description="NeuroRank Management Interface")
    subparsers = parser.add_subparsers(dest="command", help="Available commands")
    _add_server_command(subparsers)
    _add_teacher_command(subparsers)
    _add_student_command(subparsers)
    return parser


def main(argv=None):
    """Parse command-line arguments and dispatch to the chosen sub-command.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so calling
            ``main()`` with no arguments behaves as a normal script entry
            point.

    Raises:
        SystemExit: With status 1 after printing the help text when no
            sub-command was given. argparse also exits on its own for
            ``--help`` and for invalid arguments.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Each sub-parser stores its handler as ``func`` via set_defaults; the
    # attribute is absent when no sub-command was selected.
    handler = getattr(args, "func", None)
    if handler is None:
        parser.print_help()
        sys.exit(1)
    handler(args)
# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()