Skip to content

Commit 1d69b76

Browse files
committed
chore: generate span annotations via plpgsql (#7417)
1 parent 7e34714 commit 1d69b76

File tree

2 files changed

+414
-0
lines changed

2 files changed

+414
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Span Annotations Generation Script
4+
5+
This script executes the generate_span_annotations.sql script to create random annotations
6+
for spans in the database. It provides a convenient way to run the SQL script with
7+
configurable parameters.
8+
9+
The script generates random annotations with the following characteristics:
10+
- Randomly sampled spans (approximately 1% of total spans)
11+
- Between 1 and max_annotations_per_span annotations per span
12+
- Randomly assigned names from the provided list
13+
- Randomly assigned labels: "YES" or "NO"
14+
- Random scores between 0 and 1
15+
- Empty metadata JSON objects
16+
- Random annotator kind: "HUMAN" or "LLM"
17+
- Random explanation text
18+
19+
Usage:
20+
python generate_span_annotations.py [options]
21+
22+
Options:
23+
--db-name NAME Database name (default: postgres)
24+
--db-user USER Database user (default: postgres)
25+
--db-host HOST Database host (default: localhost)
26+
--db-port PORT Database port (default: 5432)
27+
--db-password PASS Database password (default: phoenix)
28+
--limit LIMIT Number of spans to sample (default: 10000)
29+
--max-annotations-per-span MAX
30+
Maximum number of annotations per span (default: 10)
31+
--label-missing-prob PROB
32+
Probability of label being missing (default: 0.1)
33+
--score-missing-prob PROB
34+
Probability of score being missing (default: 0.1)
35+
--explanation-missing-prob PROB
36+
Probability of explanation being missing (default: 0.1)
37+
--metadata-missing-prob PROB
38+
Probability of metadata being missing (default: 0.1)
39+
--annotation-names NAMES
40+
Comma-separated list of annotation names (default: correctness,helpfulness,relevance,safety,coherence,note)
41+
42+
Example:
43+
# Use default parameters
44+
python generate_span_annotations.py
45+
46+
# Specify custom parameters
47+
python generate_span_annotations.py \
48+
--db-name mydb \
49+
--db-user myuser \
50+
--db-host localhost \
51+
--db-port 5432 \
52+
--db-password mypass \
53+
--limit 10000 \
54+
--max-annotations-per-span 10 \
55+
--label-missing-prob 0.1 \
56+
--score-missing-prob 0.1 \
57+
--explanation-missing-prob 0.1 \
58+
--metadata-missing-prob 0.1 \
59+
--annotation-names "correctness,helpfulness,relevance,safety,coherence"
60+
61+
Dependencies:
62+
- Python 3.x
63+
- psql command-line tool
64+
- PostgreSQL database with the following tables:
65+
- public.spans
66+
- public.span_annotations
67+
68+
The script uses a single bulk INSERT operation for efficiency and maintains referential
69+
integrity by using the span's id as span_rowid in the annotations.
70+
""" # noqa: E501
71+
72+
import argparse
73+
import os
74+
import subprocess
75+
import sys
76+
import time
77+
from datetime import timedelta
78+
79+
80+
def parse_arguments(argv=None):
    """Parse command line arguments.

    Args:
        argv (list[str] | None): Argument strings to parse, without the program
            name. When None (the default) argparse falls back to ``sys.argv[1:]``,
            so existing zero-argument callers behave as before.

    Returns:
        argparse.Namespace: Parsed command line arguments

    Raises:
        SystemExit: Raised by argparse (status 2) when an option is malformed,
            including a ``--*-missing-prob`` value outside the range [0, 1].
    """
    parser = argparse.ArgumentParser(description="Generate span annotations")

    # Database connection options
    parser.add_argument(
        "--db-name",
        type=str,
        default="postgres",
        help="Database name (default: postgres)",
    )
    parser.add_argument(
        "--db-user",
        type=str,
        default="postgres",
        help="Database user (default: postgres)",
    )
    parser.add_argument(
        "--db-host",
        type=str,
        default="localhost",
        help="Database host (default: localhost)",
    )
    parser.add_argument(
        "--db-port",
        type=int,
        default=5432,
        help="Database port (default: 5432)",
    )
    parser.add_argument(
        "--db-password",
        type=str,
        default="phoenix",
        help="Database password (default: phoenix)",
    )

    # Generation options
    parser.add_argument(
        "--limit",
        type=int,
        default=10_000,
        help="Number of spans to sample (default: 10000)",
    )
    parser.add_argument(
        "--max-annotations-per-span",
        type=int,
        default=10,
        help="Maximum number of annotations per span (default: 10)",
    )
    # The four "missing" probabilities share type, default and help wording, so they
    # are registered in a loop; the order matches the original --help output.
    for field in ("label", "score", "explanation", "metadata"):
        parser.add_argument(
            f"--{field}-missing-prob",
            type=_probability,
            default=0.1,
            help=f"Probability of {field} being missing (default: 0.1)",
        )
    parser.add_argument(
        "--annotation-names",
        type=str,
        default="correctness,helpfulness,relevance,safety,coherence,note",
        help="Comma-separated list of annotation names (default: correctness,helpfulness,relevance,safety,coherence,note)",  # noqa: E501
    )
    return parser.parse_args(argv)


def _probability(text):
    """Convert a command line string to a float in [0, 1] for use as an argparse type."""
    try:
        value = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid probability: {text!r}") from None
    # The chained comparison is False for NaN, so "nan" is rejected as well.
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"probability must be between 0 and 1, got {text!r}")
    return value
160+
161+
162+
def run_sql_script(
    db_name,
    db_user,
    db_host,
    db_port,
    db_password,
    script_path,
    print_output=True,
    limit=10000,
    max_annotations_per_span=10,
    label_missing_prob=0.1,
    score_missing_prob=0.1,
    explanation_missing_prob=0.1,
    metadata_missing_prob=0.1,
    annotation_names="correctness,helpfulness,relevance,safety,coherence",
):
    """Run a SQL script file using psql.

    The generation parameters are handed to the script as psql variables
    (``-v name=value``). psql is started with ``ON_ERROR_STOP=1`` so that a failing
    statement inside the script produces a non-zero exit status; without it psql
    keeps going and exits with 0, and the failure would be reported as success.

    Args:
        db_name (str): Database name
        db_user (str): Database user
        db_host (str): Database host
        db_port (int): Database port
        db_password (str): Database password
        script_path (str): Path to SQL script file
        print_output (bool): Whether to print psql's output after a successful
            run (default: True)
        limit (int): Number of spans to sample and annotate (default: 10000)
        max_annotations_per_span (int): Maximum number of annotations per span
            (default: 10)
        label_missing_prob (float): Probability of label being missing (default: 0.1)
        score_missing_prob (float): Probability of score being missing (default: 0.1)
        explanation_missing_prob (float): Probability of explanation being missing
            (default: 0.1)
        metadata_missing_prob (float): Probability of metadata being missing
            (default: 0.1)
        annotation_names (str): Comma-separated list of annotation names
            (default: correctness,helpfulness,relevance,safety,coherence)

    Returns:
        bool: True if psql exited with status 0, False otherwise. On failure
        psql's stderr is printed before returning.

    Raises:
        FileNotFoundError: If the psql executable cannot be found. A failing psql
            run does not raise; it is reported through the False return value.
    """
    # The password travels in the child's environment rather than in the argument list
    env = os.environ.copy()
    env["PGPASSWORD"] = db_password

    # Escape single quotes in annotation names
    # NOTE(review): this presumes the SQL script embeds the raw value inside its own
    # quotes; if it uses psql's :'annotation_names' form the quotes would be doubled
    # twice — confirm against generate_span_annotations.sql.
    escaped_names = annotation_names.replace("'", "''")

    script_variables = {
        # Make psql stop and exit non-zero on the first SQL error
        "ON_ERROR_STOP": 1,
        "limit": limit,
        "max_annotations_per_span": max_annotations_per_span,
        "label_missing_prob": label_missing_prob,
        "score_missing_prob": score_missing_prob,
        "explanation_missing_prob": explanation_missing_prob,
        "metadata_missing_prob": metadata_missing_prob,
        "annotation_names": escaped_names,
    }

    cmd = ["psql", "-h", db_host, "-p", str(db_port), "-d", db_name, "-U", db_user]
    for name, value in script_variables.items():
        cmd += ["-v", f"{name}={value}"]
    cmd += ["-f", script_path]

    # Execute the command, capturing both streams as text
    result = subprocess.run(cmd, capture_output=True, text=True, env=env)

    # Check if the command was successful
    if result.returncode != 0:
        print("Error executing SQL script:")
        print(result.stderr)
        return False

    # Print the output if requested. stderr is reported independently of stdout so
    # that messages psql writes there are not lost when stdout happens to be empty.
    if print_output:
        if result.stdout:
            print("\nSQL Output:")
            print(result.stdout)
        if result.stderr:
            print("\nSQL Errors:")
            print(result.stderr)

    return True
254+
255+
256+
def main():
    """Entry point: generate span annotations and report how long it took.

    Steps:
    1. Read the command line options
    2. Resolve generate_span_annotations.sql in the directory of this file
    3. Pass the script path and options to run_sql_script
    4. Print " done" with the elapsed time, or print an error and exit with status 1
    """
    args = parse_arguments()

    # The SQL script is looked up next to this Python file
    sql_script_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "generate_span_annotations.sql",
    )

    try:
        # No trailing newline: the outcome is appended to this same line
        print("Generating span annotations...", end="", flush=True)
        started_at = time.time()

        generation_options = {
            "limit": args.limit,
            "max_annotations_per_span": args.max_annotations_per_span,
            "label_missing_prob": args.label_missing_prob,
            "score_missing_prob": args.score_missing_prob,
            "explanation_missing_prob": args.explanation_missing_prob,
            "metadata_missing_prob": args.metadata_missing_prob,
            "annotation_names": args.annotation_names,
        }
        succeeded = run_sql_script(
            args.db_name,
            args.db_user,
            args.db_host,
            args.db_port,
            args.db_password,
            sql_script_path,
            **generation_options,
        )
        if not succeeded:
            print(" failed")
            print("Error generating annotations. Aborting.")
            sys.exit(1)

        # Whole seconds only, rendered as H:MM:SS by timedelta
        elapsed = timedelta(seconds=int(time.time() - started_at))
        print(f" done (took {elapsed})")

    except Exception as e:
        # sys.exit above raises SystemExit, which is not an Exception subclass,
        # so it passes through this handler untouched.
        print(f"Error: {e}")
        sys.exit(1)
306+
307+
308+
# Run the generator only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)