Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 108 additions & 2 deletions ai_scientist/generate_ideas.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,93 @@
If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""

# generate overview paragraph
def generate_overview_paragraph(base_dir, client, model, engine="semanticscholar"):
    """Write and return an LLM-generated overview of literature related to a task.

    Reads ``experiment.py`` and ``prompt.json`` from ``base_dir``, asks the LLM
    for search queries, retrieves up to 5 papers per query through
    ``search_for_papers``, and asks the LLM to summarize the collected titles
    and abstracts. On success the summarization prompt and the resulting
    paragraph are saved as ``overview_prompt.txt`` and ``overview.txt`` inside
    ``base_dir``.

    Args:
        base_dir: Directory containing ``experiment.py`` and ``prompt.json``;
            the JSON file must provide a ``"task_description"`` key.
        client: LLM client, passed through to ``get_response_from_llm``.
        model: Model identifier, passed through to ``get_response_from_llm``.
        engine: Search backend name, passed through to ``search_for_papers``.

    Returns:
        The stripped overview paragraph. If no usable queries could be
        obtained, returns ``"Failed to generate search queries."``; if no
        papers were collected, returns
        ``"No relevant papers found to summarize."``.

    NOTE(review): the two early-return paths do not write ``overview.txt``,
    so any code that later opens that file must cope with it being absent
    (or left over from a previous run).
    """
    with open(osp.join(base_dir, "experiment.py"), "r") as f:
        code = f.read()
    with open(osp.join(base_dir, "prompt.json"), "r") as f:
        prompt = json.load(f)

    task_description = prompt["task_description"]

    # use LLM to generate search queries
    query_generation_prompt = f"""You are a machine learning researcher preparing a literature review.
Based on the following research task and code snippet, generate 3 to 5 search queries that would retrieve relevant papers from an academic search engine like Semantic Scholar.

TASK DESCRIPTION:
{task_description}

CODE:
<experiment.py>
{code}
</experiment.py>

Return only the search queries in JSON list format:
```json
[
"query 1",
"query 2",
...
]
```"""

    query_text, _ = get_response_from_llm(
        query_generation_prompt,
        client=client,
        model=model,
        system_message="You generate search queries for academic literature.",
        msg_history=[],
    )

    try:
        parsed_queries = extract_json_between_markers(query_text)
    except Exception as e:
        print("Failed to generate or parse search queries:", e)
        return "Failed to generate search queries."

    # Validate explicitly instead of with `assert` (asserts vanish under -O).
    # Keep only non-blank strings and drop repeated queries, preserving order.
    if isinstance(parsed_queries, list):
        query_list = list(
            dict.fromkeys(
                q.strip() for q in parsed_queries if isinstance(q, str) and q.strip()
            )
        )
    else:
        query_list = []
    if not query_list:
        print(
            "Failed to generate or parse search queries:",
            "expected a non-empty JSON list of query strings",
        )
        return "Failed to generate search queries."

    print("Generated queries:", query_list)

    # use the generated queries to retrieve papers
    all_paper_summaries = []
    seen_titles = set()  # normalized titles, so overlapping queries don't repeat a paper
    for query in query_list:
        # Only the search call is best-effort; one failing query must not
        # abort the remaining ones.
        try:
            papers = search_for_papers(query, result_limit=5, engine=engine)
        except Exception as e:
            print(f"Query failed: {query}, error: {e}")
            continue
        if not papers:
            continue
        for p in papers:
            # Skip malformed records individually rather than discarding the
            # rest of this query's results.
            if not isinstance(p, dict):
                continue
            title = p.get("title")
            if not title:
                continue
            title_key = str(title).strip().casefold()
            if title_key in seen_titles:
                continue
            seen_titles.add(title_key)
            # A missing abstract would otherwise be rendered as "None".
            abstract = p.get("abstract") or "No abstract available."
            all_paper_summaries.append(f"Title: {title}\nAbstract: {abstract}\n")

    if not all_paper_summaries:
        print("No relevant papers found across all queries.")
        return "No relevant papers found to summarize."

    # summarize all paper abstracts
    paper_str = "\n\n".join(all_paper_summaries)

    overview_prompt = f"""You are a highly skilled AI researcher. The following papers were retrieved related to the task:

{paper_str}

Summarize the current research directions and trends based on these papers. Write a concise and insightful paragraph to help guide future experiment ideas. Only output the paragraph without any commentary."""

    overview_text, _ = get_response_from_llm(
        overview_prompt,
        client=client,
        model=model,
        system_message="You summarize academic papers into insightful overviews.",
        msg_history=[],
    )

    # Persist both the prompt and the answer for later inspection and reuse.
    overview = overview_text.strip()
    with open(osp.join(base_dir, "overview_prompt.txt"), "w") as f:
        f.write(overview_prompt.strip())
    with open(osp.join(base_dir, "overview.txt"), "w") as f:
        f.write(overview)

    return overview

# GENERATE IDEAS
def generate_ideas(
Expand Down Expand Up @@ -108,6 +195,13 @@ def generate_ideas(
prompt = json.load(f)

idea_system_prompt = prompt["system"]

# add overview as a condition
with open(osp.join(base_dir, "overview.txt"), "r") as f:
overview_paragraph = f.read()

# And include in idea_first_prompt:
idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt

for _ in range(max_num_generations):
print()
Expand All @@ -118,7 +212,8 @@ def generate_ideas(
msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
# idea_first_prompt.format(
idea_first_prompt_with_overview.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
Expand Down Expand Up @@ -202,6 +297,13 @@ def generate_next_idea(
prompt = json.load(f)
idea_system_prompt = prompt["system"]

# add overview as a condition
with open(osp.join(base_dir, "overview.txt"), "r") as f:
overview_paragraph = f.read()

# And include in idea_first_prompt:
idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt

for _ in range(max_attempts):
try:
idea_strings = []
Expand All @@ -212,7 +314,8 @@ def generate_next_idea(
msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
idea_first_prompt.format(
# idea_first_prompt.format(
idea_first_prompt_with_overview.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
Expand Down Expand Up @@ -529,6 +632,9 @@ def check_idea_novelty(

base_dir = osp.join("templates", args.experiment)
results_dir = osp.join("results", args.experiment)

overview = generate_overview_paragraph(base_dir, client, client_model)

ideas = generate_ideas(
base_dir,
client=client,
Expand Down
Loading