-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathexperiment_runner.py
More file actions
116 lines (93 loc) · 3.79 KB
/
experiment_runner.py
File metadata and controls
116 lines (93 loc) · 3.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import asyncio
import json
import logging
import datetime
from dotenv import load_dotenv
from app import process_chat_message_sync
from galileo.datasets import get_dataset
from galileo.experiments import run_experiment
from galileo import galileo_context
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Load environment variables from a local .env file, if one exists.
load_dotenv()

# Environment variable -> Streamlit secret key that populates it. Both Galileo
# project variables are intentionally fed from the same secret.
_STREAMLIT_SECRET_ENV_VARS = {
    "GALILEO_API_KEY": "galileo_api_key",
    "GALILEO_PROJECT": "galileo_project",
    "GALILEO_PROJECT_NAME": "galileo_project",
    "GALILEO_LOG_STREAM_NAME": "galileo_log_stream",
    "GALILEO_CONSOLE_URL": "galileo_console_url",
}


def _apply_streamlit_secrets():
    """Copy Galileo settings from Streamlit secrets into ``os.environ``.

    Values found in ``st.secrets`` override whatever ``load_dotenv`` set.
    A missing ``streamlit`` install, a missing secrets file, or an absent key
    is tolerated (logged at DEBUG) so the script can also be configured from
    a plain ``.env`` file; ``main`` then reports any required variable that
    is still unset.
    """
    try:
        # Imported locally: the original code used ``st`` without importing it.
        import streamlit as st
    except ImportError:
        logger.debug("streamlit is not installed; using environment variables only")
        return

    for env_var, secret_key in _STREAMLIT_SECRET_ENV_VARS.items():
        try:
            value = st.secrets[secret_key]
        except Exception as e:
            # Accessing st.secrets can raise KeyError for an absent key, or a
            # Streamlit-version-dependent error when no secrets file exists;
            # treat all of these as "not configured via Streamlit".
            logger.debug(
                "Streamlit secret %r unavailable (%s); keeping environment value",
                secret_key,
                e,
            )
            continue
        # os.environ only accepts str; TOML secrets may be other scalar types.
        os.environ[env_var] = str(value)


_apply_streamlit_secrets()
def process_trade_prompt(example):
    """Answer one experiment example with the trading-assistant chat pipeline.

    Passed as the ``function`` callback to Galileo's ``run_experiment`` in
    ``main``. The call is synchronous: it forwards ``example`` to
    ``process_chat_message_sync`` with a fixed model, system prompt and RAG
    settings, starting from an empty message history on every call.

    Args:
        example: The value supplied by ``run_experiment`` for one dataset
            row; forwarded unchanged as ``prompt``. (Presumably the row's
            input text -- confirm against the Galileo SDK.)

    Returns:
        The ``content`` of the pipeline's ``response_message`` on success.
        If the pipeline raises, or its result lacks a usable
        ``response_message``, returns the string ``"Error: <message>"`` so
        every row records a string output (previously a dict was returned
        on error, inconsistent with the success path).
    """
    model = "gpt-4"
    system_prompt = (
        "You are a stock market analyst and trading assistant. "
        "You help users analyze stocks and execute trades."
    )

    logger.info("Processing prompt: %s", example)
    try:
        result = process_chat_message_sync(
            prompt=example,
            message_history=[],  # each example is an independent conversation
            model=model,
            system_prompt=system_prompt,
            use_rag=True,
            namespace="sp500-qa-demo",
            top_k=3,
            galileo_logger=galileo_context.get_logger_instance(),
            ambiguous_tool_names=True,
            is_streamlit=False,
        )
        # Inside the try so a malformed result is reported for this row
        # instead of aborting the whole experiment.
        response = result["response_message"].content
    except Exception as e:
        logger.exception("Error processing prompt: %s", e)
        return f"Error: {e}"

    # Collect tool/RAG usage for the log; it is not part of the return value.
    metadata = {}
    tool_results = result.get("tool_results")
    if tool_results:
        metadata["tools_used"] = [tool["name"] for tool in tool_results]
    rag_documents = result.get("rag_documents")
    if rag_documents:
        metadata["rag_documents_count"] = len(rag_documents)
    logger.info("Response metadata: %s", json.dumps(metadata))

    # Return just the response text for Galileo logging.
    return response
def main():
    """Run a timestamped ``trade_tool_selection`` experiment on the "trades" dataset.

    Checks that the required credentials are present in the environment,
    fetches the ``trades`` dataset, and runs ``process_trade_prompt`` over it
    with the ``correctness`` metric in the project named by
    ``GALILEO_PROJECT_NAME``.

    Returns:
        0 on success, 1 if configuration is missing or any step fails. The
        ``__main__`` guard hands this to ``SystemExit`` so failures produce a
        non-zero process exit status (previously the script always exited 0).
    """
    # Fail fast, before any network calls, if credentials are not configured.
    required_vars = ["GALILEO_API_KEY", "GALILEO_PROJECT_NAME", "OPENAI_API_KEY"]
    missing_vars = [var for var in required_vars if not os.getenv(var)]
    if missing_vars:
        logger.error(
            "Missing required environment variables: %s", ", ".join(missing_vars)
        )
        return 1

    galileo_project = os.getenv("GALILEO_PROJECT_NAME")

    try:
        dataset = get_dataset(name="trades")
    except Exception as e:
        logger.exception("Error getting dataset 'trades': %s", e)
        return 1
    # NOTE(review): guard in case the SDK returns None for an unknown name
    # instead of raising -- confirm against the Galileo SDK.
    if dataset is None:
        logger.error("Error getting dataset 'trades': not found")
        return 1

    # Timestamp suffix keeps each run's experiment name unique.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    experiment_name = f"trade_tool_selection_{timestamp}"
    logger.info("Starting experiment: %s", experiment_name)

    try:
        results = run_experiment(
            experiment_name,
            dataset=dataset,
            function=process_trade_prompt,
            metrics=["correctness"],
            project=galileo_project,
        )
    except Exception as e:
        logger.exception("Error running experiment: %s", e)
        return 1

    logger.info("Experiment completed: %s", experiment_name)
    logger.debug("Experiment results: %s", results)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())