Skip to content
This repository was archived by the owner on Apr 30, 2026. It is now read-only.

Commit 37ef74b

Browse files
committed
feat: adding driver script to src/sdg/
Signed-off-by: eshwarprasadS <eshwarprasad.s01@gmail.com>
1 parent 3252d3f commit 37ef74b

1 file changed

Lines changed: 212 additions & 0 deletions

File tree

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
# Standard
2+
from datetime import datetime
3+
from multiprocessing import set_start_method
4+
import argparse
5+
import logging
6+
import os
7+
import sys
8+
9+
# First Party
10+
from instructlab.sdg.subset_selection import subset_datasets
11+
12+
13+
def setup_logging(log_level="INFO", log_file=None):
    """Configure root logging and return a logger for this module.

    Removes any pre-existing root handlers, attaches a stdout handler (and an
    optional file handler), raises the application loggers to *log_level*,
    and quiets known-noisy third-party loggers.

    Args:
        log_level: Level name applied to application loggers (e.g. "DEBUG",
            "INFO"). Case-insensitive.
        log_file: Optional path to a log file; its parent directory is
            created if needed.

    Returns:
        logging.Logger: The logger for this module (``__name__``).

    Raises:
        ValueError: If *log_level* is not a recognized level name.
    """
    level = getattr(logging, log_level.upper(), None)
    if not isinstance(level, int):
        raise ValueError(f"Invalid log level: {log_level}")

    root = logging.getLogger()

    # Drop existing handlers so repeated calls don't duplicate output.
    for existing in list(root.handlers):
        root.removeHandler(existing)

    # Root stays at INFO so third-party libraries don't flood DEBUG output;
    # only the application loggers below get the requested verbosity.
    root.setLevel(logging.INFO)

    formatter = logging.Formatter(
        "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    )

    stdout_handler = logging.StreamHandler(sys.stdout)
    stdout_handler.setFormatter(formatter)
    root.addHandler(stdout_handler)

    if log_file:
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        root.addHandler(file_handler)

    # Application loggers get the caller-requested level.
    for name in ("instructlab", "scripts", "__main__"):
        logging.getLogger(name).setLevel(level)

    # Clamp noisy libraries, but only if they have already registered a logger.
    for name in ("matplotlib", "PIL", "submodlib", "transformers", "torch", "numpy"):
        if name in logging.root.manager.loggerDict:
            logging.getLogger(name).setLevel(logging.WARNING)

    return logging.getLogger(__name__)
60+
61+
62+
def parse_size(value):
    """Parse a subset-size CLI argument into an int or a float.

    Values that parse to a whole number (e.g. "10" or "10.0") are returned
    as ``int`` (absolute counts); other values are returned as ``float``
    (fractions/percentages).

    Args:
        value: The raw command-line string.

    Returns:
        int | float: The parsed size.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a finite number.
    """
    # Local import keeps this script's top-level import block unchanged.
    import math

    try:
        float_value = float(value)
    except ValueError as exc:
        raise argparse.ArgumentTypeError(
            f"Invalid size value: {value}. Must be a number."
        ) from exc
    # float() happily accepts "nan"/"inf", but neither is a meaningful
    # subset size — reject them explicitly.
    if not math.isfinite(float_value):
        raise argparse.ArgumentTypeError(
            f"Invalid size value: {value}. Must be a finite number."
        )
    if float_value.is_integer():
        return int(float_value)
    return float_value
73+
74+
75+
def parse_args():
    """Build the CLI parser for subset selection and return parsed arguments.

    Returns:
        argparse.Namespace: All command-line options.
    """
    parser = argparse.ArgumentParser(
        description="Run subset selection on datasets using facility location method.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add_arg = parser.add_argument

    # --- required inputs ---------------------------------------------------
    add_arg(
        "--input_files",
        nargs="+",
        required=True,
        help="One or more input files (space-separated) to process",
    )
    add_arg(
        "--output_dir", required=True, help="Directory where output files will be saved"
    )
    add_arg(
        "--subset_sizes",
        nargs="+",
        type=parse_size,
        required=True,
        help="One or more subset sizes (space-separated). Values between 0-1 represent percentages, integers represent absolute counts",
    )

    # --- optional tuning knobs ---------------------------------------------
    add_arg(
        "--num_folds",
        type=int,
        default=50,
        help="Number of folds to use for subset selection (launches separate processes for each fold)",
    )
    add_arg(
        "--batch_size",
        type=int,
        default=100000,
        help="Batch size for processing embeddings",
    )
    add_arg(
        "--num_gpus",
        type=int,
        default=None,
        help="Number of GPUs to use. If not specified, all available GPUs will be used, if specified more than available, max available will be used",
    )
    add_arg(
        "--encoder_type",
        default="arctic",
        help="Type of encoder to use for generating embeddings",
    )
    add_arg(
        "--encoder_model",
        default="Snowflake/snowflake-arctic-embed-l-v2.0",
        help="Model to use for generating embeddings, please download using ilab model download prior to using",
    )
    add_arg(
        "--epsilon",
        type=float,
        default=160.0,
        help="Epsilon parameter for the LazierThanLazyGreedy optimizer. Default is optimized for datasets >100k samples. For smaller datasets, use smaller values (starting from 0.1)",
    )
    add_arg(
        "--template_name",
        default="conversation",
        help="Template name to use for formatting examples. Options: default, conversation, qa",
    )
    add_arg(
        "--testing_mode",
        action="store_true",
        help="Run in testing mode (limited computation), not for actual use",
    )
    add_arg(
        "--combine_files",
        action="store_true",
        help="Combine all input files into a single dataset",
    )

    # --- logging -----------------------------------------------------------
    add_arg(
        "--log_dir",
        default=None,
        help="Directory to store log files. If not specified, logs will only be printed to console",
    )
    add_arg(
        "--log_level",
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Set the logging level",
    )

    return parser.parse_args()
175+
176+
177+
def _build_log_file_path(log_dir):
    """Return a timestamped log-file path inside *log_dir*, creating the directory."""
    os.makedirs(log_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(log_dir, f"subset_selection_{timestamp}.log")


def main():
    """CLI entry point: parse arguments, configure logging, run subset selection.

    Exits with status 1 if subset selection raises any exception.
    """
    args = parse_args()

    # Log to a timestamped file only when --log_dir was given.
    output_log_file = _build_log_file_path(args.log_dir) if args.log_dir else None
    logger = setup_logging(args.log_level, output_log_file)
    logger.info("Starting subset selection with arguments: %s", args)

    kwargs = vars(args)

    # Logging options are consumed here; subset_datasets doesn't take them.
    kwargs.pop("log_dir", None)
    kwargs.pop("log_level", None)

    input_files = kwargs.pop("input_files")
    subset_sizes = kwargs.pop("subset_sizes")

    # Drop unset options so subset_datasets applies its own defaults.
    kwargs = {k: v for k, v in kwargs.items() if v is not None}

    try:
        logger.info(
            "Running subset selection on %s with sizes %s", input_files, subset_sizes
        )
        subset_datasets(input_files=input_files, subset_sizes=subset_sizes, **kwargs)
        logger.info("Subset selection completed successfully")
    # Top-level boundary: log the full traceback and exit non-zero.
    # pylint: disable=broad-exception-caught
    except Exception as e:
        logger.exception("Error during subset selection: %s", e)
        sys.exit(1)


if __name__ == "__main__":
    # "spawn" must be set before any worker processes are created; it avoids
    # fork-related issues with CUDA/threaded libraries in the fold workers.
    set_start_method("spawn")
    main()

0 commit comments

Comments
 (0)