Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,6 @@ research_dir/*
state_saves/*
__pycache__/*
Figure*.png
testrun.py
testrun.py
data/*
projects/*
62 changes: 53 additions & 9 deletions ai_lab_repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ def data_preparation(self):
if self.verbose: print("#"*40, f"\nThe following is dialogue produced by the SW Engineer: {dialogue}", "\n", "#"*40)
if "```SUBMIT_CODE" in resp:
final_code = extract_prompt(resp, "SUBMIT_CODE")
code_resp = execute_code(final_code, timeout=60)
code_resp = execute_code(final_code)
if self.verbose: print("!"*100, "\n", f"CODE RESPONSE: {code_resp}")
swe_feedback += f"\nCode Response: {code_resp}\n"
if "[CODE EXECUTION ERROR]" in code_resp:
Expand Down Expand Up @@ -389,7 +389,7 @@ def data_preparation(self):
if "```python" in resp:
code = extract_prompt(resp, "python")
code = self.ml_engineer.dataset_code + "\n" + code
code_resp = execute_code(code, timeout=120)
code_resp = execute_code(code)
ml_command = f"Code produced by the ML agent:\n{code}"
ml_feedback += f"\nCode Response: {code_resp}\n"
if self.verbose: print("!"*100, "\n", f"CODE RESPONSE: {code_resp}")
Expand Down Expand Up @@ -446,10 +446,13 @@ def literature_review(self):
@return: (bool) whether to repeat the phase
"""
arx_eng = ArxivSearch()
max_tries = self.max_steps * 5 # lit review often requires extra steps
max_tries = self.max_steps * 5 # lit review often requires extra steps

# get initial response from PhD agent
resp = self.phd.inference(self.research_topic, "literature review", step=0, temp=0.8)
if self.verbose: print(resp, "\n~~~~~~~~~~~")
if self.verbose:
print(resp, "\n~~~~~~~~~~~")

# iterate until max num tries to complete task is exhausted
for _i in range(max_tries):
feedback = str()
Expand All @@ -463,40 +466,71 @@ def literature_review(self):
# grab full text from arxiv ID
elif "```FULL_TEXT" in resp:
query = extract_prompt(resp, "FULL_TEXT")
# expiration timer so that paper does not remain in context too long
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n" + arx_eng.retrieve_full_paper_text(query) + "```"
try:
# expiration timer so that paper does not remain in context too long
full_text_content = arx_eng.retrieve_full_paper_text(query)
except Exception as e:
# Catch any unexpected errors from arxiv.Client()
err_msg = f"[ERROR] Could not retrieve paper. Possibly invalid arXiv ID. Error: {e}"
full_text_content = err_msg

# In case retrieve_full_paper_text returns an error string
# or if we want to unify it, e.g. "[ERROR] ...something"
if full_text_content.startswith("[ERROR]"):
# We won't crash; just pass that as feedback
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```"
else:
# normal successful retrieval
arxiv_paper = f"```EXPIRATION {self.arxiv_paper_exp_time}\n{full_text_content}```"

feedback = arxiv_paper

# if add paper, extract and add to lit review, provide feedback
elif "```ADD_PAPER" in resp:
query = extract_prompt(resp, "ADD_PAPER")
feedback, text = self.phd.add_review(query, arx_eng)
# If we want to store reference text for later usage
if len(self.reference_papers) < self.num_ref_papers:
self.reference_papers.append(text)

# completion condition
if len(self.phd.lit_review) >= self.num_papers_lit_review:
# generate formal review
lit_review_sum = self.phd.format_review()

# if human in loop -> check if human is happy with the produced review
if self.human_in_loop_flag["literature review"]:
retry = self.human_in_loop("literature review", lit_review_sum)
# if not happy, repeat the process with human feedback
if retry:
self.phd.lit_review = []
return retry

# otherwise, return lit review and move on to next stage
if self.verbose: print(self.phd.lit_review_sum)
if self.verbose:
print(self.phd.lit_review_sum)
# set agent
self.set_agent_attr("lit_review_sum", lit_review_sum)
# reset agent state
self.reset_agents()
self.statistics_per_phase["literature review"]["steps"] = _i
return False
resp = self.phd.inference(self.research_topic, "literature review", feedback=feedback, step=_i + 1, temp=0.8)
if self.verbose: print(resp, "\n~~~~~~~~~~~")

# Move on to the next iteration with new feedback
resp = self.phd.inference(
self.research_topic,
"literature review",
feedback=feedback,
step=_i + 1,
temp=0.8
)
if self.verbose:
print(resp, "\n~~~~~~~~~~~")

# If we exceed max_tries:
raise Exception("Max tries during phase: Literature Review")


def human_in_loop(self, phase, phase_prod):
"""
Get human feedback for phase output
Expand Down Expand Up @@ -611,6 +645,7 @@ def parse_arguments():
help='Total number of paper-solver steps'
)

parser.add_argument('--file-path', type=str, default=None)

return parser.parse_args()

Expand All @@ -622,6 +657,8 @@ def parse_arguments():
human_mode = args.copilot_mode.lower() == "true"
compile_pdf = args.compile_latex.lower() == "true"
load_existing = args.load_existing.lower() == "true"
file_path = args.file_path

try:
num_papers_lit_review = int(args.num_papers_lit_review.lower())
except Exception:
Expand Down Expand Up @@ -654,6 +691,13 @@ def parse_arguments():
else:
research_topic = args.research_topic

if file_path and "{FILE}" in research_topic:
with open(file_path, 'r', encoding='utf-8') as f:
file_content = f.read()
# Replace the placeholder with the entire file text
research_topic = research_topic.replace("{FILE}", file_content)


task_notes_LLM = [
{"phases": ["plan formulation"],
"note": f"You should come up with a plan for TWO experiments."},
Expand Down
5 changes: 4 additions & 1 deletion inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from openai import OpenAI
import openai
import os, anthropic, json
from utils import clip_tokens

TOKENS_IN = dict()
TOKENS_OUT = dict()
Expand Down Expand Up @@ -29,7 +30,7 @@ def curr_cost_est():
}
return sum([costmap_in[_]*TOKENS_IN[_] for _ in TOKENS_IN]) + sum([costmap_out[_]*TOKENS_OUT[_] for _ in TOKENS_OUT])

def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5"):
def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic_api_key=None, tries=5, timeout=5.0, temp=None, print_cost=True, version="1.5", max_context_tokens=128000):
preloaded_api = os.getenv('OPENAI_API_KEY')
if openai_api_key is None and preloaded_api is not None:
openai_api_key = preloaded_api
Expand All @@ -47,6 +48,8 @@ def query_model(model_str, prompt, system_prompt, openai_api_key=None, anthropic
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt}]

messages = clip_tokens(messages, model=model_str, max_tokens=max_context_tokens)
if version == "0.28":
if temp is None:
completion = openai.ChatCompletion.create(
Expand Down
28 changes: 25 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,32 +10,41 @@ arxiv==2.1.3
astunparse==1.6.3
async-timeout==5.0.1
attrs==24.2.0
beautifulsoup4==4.12.3
blis==1.0.1
catalogue==2.0.10
certifi==2024.8.30
charset-normalizer==3.4.0
click==8.1.7
cloudpathlib==0.20.0
cloudpickle==3.1.1
confection==0.1.5
contourpy==1.3.0
cycler==0.12.1
cymem==2.0.10
datasets==3.1.0
diffusers==0.31.0
dill==0.3.8
dill==0.3.9
distro==1.9.0
EMD-signal @ git+https://github.com/laszukdawid/PyEMD.git@4fc40017c1db8f1fceda4370a12314e1dedf8dde
exceptiongroup==1.2.2
Farama-Notifications==0.0.4
feedparser==6.0.11
filelock==3.16.1
flatbuffers==24.3.25
fonttools==4.55.0
frozendict==2.4.6
frozenlist==1.5.0
fsspec==2024.9.0
gast==0.6.0
google-pasta==0.2.0
grpcio==1.68.0
gym==0.26.2
gym-notices==0.0.8
gymnasium==1.0.0
h11==0.14.0
h5py==3.12.1
html5lib==1.1
httpcore==1.0.7
httpx==0.27.2
huggingface-hub==0.26.2
Expand All @@ -52,6 +61,7 @@ langcodes==3.5.0
language_data==1.3.0
lazy_loader==0.4
libclang==18.1.1
lxml==5.3.0
marisa-trie==1.2.1
Markdown==3.7
markdown-it-py==3.0.0
Expand All @@ -61,33 +71,41 @@ mdurl==0.1.2
ml-dtypes==0.4.1
mpmath==1.3.0
multidict==6.1.0
multiprocess==0.70.16
multiprocess==0.70.17
multitasking==0.0.11
murmurhash==1.0.11
namex==0.0.8
nest-asyncio==1.6.0
networkx==3.2.1
nltk==3.9.1
numpy==2.0.2
numpy==1.26.4
openai==1.55.1
opt_einsum==3.4.0
optree==0.13.1
packaging==24.2
pandas==2.2.3
pathos==0.3.3
patsy==1.0.1
peewee==3.17.8
pillow==11.0.0
platformdirs==4.3.6
plotly==5.24.1
pox==0.3.5
ppft==1.7.6.9
preshed==3.0.9
propcache==0.2.0
protobuf==5.28.3
psutil==6.1.0
pyarrow==18.1.0
pydantic==2.10.2
pydantic_core==2.27.1
pyemd==1.0.0
Pygments==2.18.0
pyparsing==3.2.0
pypdf==5.1.0
python-dateutil==2.9.0.post0
pytz==2024.2
PyWavelets==1.8.0
PyYAML==6.0.2
regex==2024.11.6
requests==2.32.3
Expand All @@ -104,10 +122,12 @@ shellingham==1.5.4
six==1.16.0
smart-open==7.0.5
sniffio==1.3.1
soupsieve==2.6
spacy==3.8.2
spacy-legacy==3.0.12
spacy-loggers==1.0.5
srsly==2.4.8
stable_baselines3==2.4.1
statsmodels==0.14.4
sympy==1.13.1
tenacity==9.0.0
Expand All @@ -130,8 +150,10 @@ tzdata==2024.2
urllib3==2.2.3
wasabi==1.1.3
weasel==0.4.1
webencodings==0.5.1
Werkzeug==3.1.3
wrapt==1.17.0
xxhash==3.5.0
yarl==1.18.0
yfinance==0.2.51
zipp==3.21.0
Loading