diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py
index 7b53b81b..20d8c234 100644
--- a/ai_scientist/generate_ideas.py
+++ b/ai_scientist/generate_ideas.py
@@ -71,6 +71,93 @@
If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
+# generate overview paragraph
+def generate_overview_paragraph(base_dir, client, model, engine="semanticscholar"):
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+
+ task_description = prompt["task_description"]
+
+ # use LLM to generate search queries
+ query_generation_prompt = f"""You are a machine learning researcher preparing a literature review.
+Based on the following research task and code snippet, generate 3 to 5 search queries that would retrieve relevant papers from an academic search engine like Semantic Scholar.
+
+TASK DESCRIPTION:
+{task_description}
+
+CODE:
+
+{code}
+
+
+Return only the search queries in JSON list format:
+```json
+[
+ "query 1",
+ "query 2",
+ ...
+]
+```"""
+
+ query_text, _ = get_response_from_llm(
+ query_generation_prompt,
+ client=client,
+ model=model,
+ system_message="You generate search queries for academic literature.",
+ msg_history=[]
+ )
+
+ try:
+ query_list = extract_json_between_markers(query_text)
+ assert isinstance(query_list, list) and len(query_list) > 0
+ except Exception as e:
+ print("Failed to generate or parse search queries:", e)
+ return "Failed to generate search queries."
+
+ print("Generated queries:", query_list)
+
+ # use the generated queries to retrieve papers
+ all_paper_summaries = []
+ for query in query_list:
+ try:
+ papers = search_for_papers(query, result_limit=5, engine=engine)
+ if not papers:
+ continue
+ for p in papers:
+ all_paper_summaries.append(f"Title: {p['title']}\nAbstract: {p['abstract']}\n")
+ except Exception as e:
+ print(f"Query failed: {query}, error: {e}")
+ continue
+
+ if len(all_paper_summaries) == 0:
+ print("No relevant papers found across all queries.")
+ return "No relevant papers found to summarize."
+
+ # summarize all paper abstracts
+ paper_str = "\n\n".join(all_paper_summaries)
+
+ overview_prompt = f"""You are a highly skilled AI researcher. The following papers were retrieved related to the task:
+
+{paper_str}
+
+Summarize the current research directions and trends based on these papers. Write a concise and insightful paragraph to help guide future experiment ideas. Only output the paragraph without any commentary."""
+
+ overview_text, _ = get_response_from_llm(
+ overview_prompt,
+ client=client,
+ model=model,
+ system_message="You summarize academic papers into insightful overviews.",
+ msg_history=[]
+ )
+
+ with open(osp.join(base_dir, "overview_prompt.txt"), "w") as f:
+ f.write(overview_prompt.strip())
+ with open(osp.join(base_dir, "overview.txt"), "w") as f:
+ f.write(overview_text.strip())
+
+ return overview_text.strip()
# GENERATE IDEAS
def generate_ideas(
@@ -108,6 +195,13 @@ def generate_ideas(
prompt = json.load(f)
idea_system_prompt = prompt["system"]
+
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
for _ in range(max_num_generations):
print()
@@ -118,7 +212,8 @@ def generate_ideas(
msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
- idea_first_prompt.format(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
@@ -202,6 +297,13 @@ def generate_next_idea(
prompt = json.load(f)
idea_system_prompt = prompt["system"]
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
+
for _ in range(max_attempts):
try:
idea_strings = []
@@ -212,7 +314,8 @@ def generate_next_idea(
msg_history = []
print(f"Iteration 1/{num_reflections}")
text, msg_history = get_response_from_llm(
- idea_first_prompt.format(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
task_description=prompt["task_description"],
code=code,
prev_ideas_string=prev_ideas_string,
@@ -529,6 +632,9 @@ def check_idea_novelty(
base_dir = osp.join("templates", args.experiment)
results_dir = osp.join("results", args.experiment)
+
+ overview = generate_overview_paragraph(base_dir, client, client_model)
+
ideas = generate_ideas(
base_dir,
client=client,
diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py
index 27c9eee8..20d8c234 100644
--- a/ai_scientist/llm.py
+++ b/ai_scientist/llm.py
@@ -1,351 +1,652 @@
import json
import os
-import re
+import os.path as osp
+import time
+from typing import List, Dict, Union
-import anthropic
import backoff
-import openai
-import google.generativeai as genai
-from google.generativeai.types import GenerationConfig
-
-MAX_NUM_TOKENS = 4096
-
-AVAILABLE_LLMS = [
- # Anthropic models
- "claude-3-5-sonnet-20240620",
- "claude-3-5-sonnet-20241022",
- # OpenAI models
- "gpt-4o-mini",
- "gpt-4o-mini-2024-07-18",
- "gpt-4o",
- "gpt-4o-2024-05-13",
- "gpt-4o-2024-08-06",
- "gpt-4.1",
- "gpt-4.1-2025-04-14",
- "gpt-4.1-mini",
- "gpt-4.1-mini-2025-04-14",
- "gpt-4.1-nano",
- "gpt-4.1-nano-2025-04-14",
- "o1",
- "o1-2024-12-17",
- "o1-preview-2024-09-12",
- "o1-mini",
- "o1-mini-2024-09-12",
- "o3-mini",
- "o3-mini-2025-01-31",
- # OpenRouter models
- "llama3.1-405b",
- # Anthropic Claude models via Amazon Bedrock
- "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
- "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
- "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
- "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
- "bedrock/anthropic.claude-3-opus-20240229-v1:0",
- # Anthropic Claude models Vertex AI
- "vertex_ai/claude-3-opus@20240229",
- "vertex_ai/claude-3-5-sonnet@20240620",
- "vertex_ai/claude-3-5-sonnet-v2@20241022",
- "vertex_ai/claude-3-sonnet@20240229",
- "vertex_ai/claude-3-haiku@20240307",
- # DeepSeek models
- "deepseek-chat",
- "deepseek-coder",
- "deepseek-reasoner",
- # Google Gemini models
- "gemini-1.5-flash",
- "gemini-1.5-pro",
- "gemini-2.0-flash",
- "gemini-2.0-flash-lite",
- "gemini-2.0-flash-thinking-exp-01-21",
- "gemini-2.5-pro-preview-03-25",
- "gemini-2.5-pro-exp-03-25",
+import requests
+
+from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
+
+S2_API_KEY = os.getenv("S2_API_KEY")
+
+idea_first_prompt = """{task_description}
+
+{code}
+
+
+Here are the ideas that you have already generated:
+
+'''
+{prev_ideas_string}
+'''
+
+Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
+Note that you will not have access to any additional resources or datasets.
+Make sure any idea is not overfit the specific training dataset or model, and has wider significance.
+
+Respond in the following format:
+
+THOUGHT:
+
+
+NEW IDEA JSON:
+```json
+
+```
+
+In , first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
+
+In , provide the new idea in JSON format with the following fields:
+- "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
+- "Title": A title for the idea, will be used for the report writing.
+- "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
+- "Interestingness": A rating from 1 to 10 (lowest to highest).
+- "Feasibility": A rating from 1 to 10 (lowest to highest).
+- "Novelty": A rating from 1 to 10 (lowest to highest).
+
+Be cautious and realistic on your ratings.
+This JSON will be automatically parsed, so ensure the format is precise.
+You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
+"""
+
+idea_reflection_prompt = """Round {current_round}/{num_reflections}.
+In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
+Include any other factors that you think are important in evaluating the idea.
+Ensure the idea is clear and concise, and the JSON is the correct format.
+Do not make things overly complicated.
+In the next attempt, try and refine and improve your idea.
+Stick to the spirit of the original idea unless there are glaring issues.
+
+Respond in the same format as before:
+THOUGHT:
+
+
+NEW IDEA JSON:
+```json
+
+```
+
+If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
+ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
+
+# generate overview paragraph
+def generate_overview_paragraph(base_dir, client, model, engine="semanticscholar"):
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+
+ task_description = prompt["task_description"]
+
+ # use LLM to generate search queries
+ query_generation_prompt = f"""You are a machine learning researcher preparing a literature review.
+Based on the following research task and code snippet, generate 3 to 5 search queries that would retrieve relevant papers from an academic search engine like Semantic Scholar.
+
+TASK DESCRIPTION:
+{task_description}
+
+CODE:
+
+{code}
+
+
+Return only the search queries in JSON list format:
+```json
+[
+ "query 1",
+ "query 2",
+ ...
]
+```"""
+
+ query_text, _ = get_response_from_llm(
+ query_generation_prompt,
+ client=client,
+ model=model,
+ system_message="You generate search queries for academic literature.",
+ msg_history=[]
+ )
+
+ try:
+ query_list = extract_json_between_markers(query_text)
+ assert isinstance(query_list, list) and len(query_list) > 0
+ except Exception as e:
+ print("Failed to generate or parse search queries:", e)
+ return "Failed to generate search queries."
+
+ print("Generated queries:", query_list)
+
+ # use the generated queries to retrieve papers
+ all_paper_summaries = []
+ for query in query_list:
+ try:
+ papers = search_for_papers(query, result_limit=5, engine=engine)
+ if not papers:
+ continue
+ for p in papers:
+ all_paper_summaries.append(f"Title: {p['title']}\nAbstract: {p['abstract']}\n")
+ except Exception as e:
+ print(f"Query failed: {query}, error: {e}")
+ continue
+
+ if len(all_paper_summaries) == 0:
+ print("No relevant papers found across all queries.")
+ return "No relevant papers found to summarize."
+
+ # summarize all paper abstracts
+ paper_str = "\n\n".join(all_paper_summaries)
+
+ overview_prompt = f"""You are a highly skilled AI researcher. The following papers were retrieved related to the task:
+{paper_str}
-# Get N responses from a single message, used for ensembling.
-@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
-def get_batch_responses_from_llm(
- msg,
+Summarize the current research directions and trends based on these papers. Write a concise and insightful paragraph to help guide future experiment ideas. Only output the paragraph without any commentary."""
+
+ overview_text, _ = get_response_from_llm(
+ overview_prompt,
+ client=client,
+ model=model,
+ system_message="You summarize academic papers into insightful overviews.",
+ msg_history=[]
+ )
+
+ with open(osp.join(base_dir, "overview_prompt.txt"), "w") as f:
+ f.write(overview_prompt.strip())
+ with open(osp.join(base_dir, "overview.txt"), "w") as f:
+ f.write(overview_text.strip())
+
+ return overview_text.strip()
+
+# GENERATE IDEAS
+def generate_ideas(
+ base_dir,
client,
model,
- system_message,
- print_debug=False,
- msg_history=None,
- temperature=0.75,
- n_responses=1,
+ skip_generation=False,
+ max_num_generations=20,
+ num_reflections=5,
):
- if msg_history is None:
- msg_history = []
+ if skip_generation:
+ # Load existing ideas from file
+ try:
+ with open(osp.join(base_dir, "ideas.json"), "r") as f:
+ ideas = json.load(f)
+ print("Loaded existing ideas:")
+ for idea in ideas:
+ print(idea)
+ return ideas
+ except FileNotFoundError:
+ print("No existing ideas found. Generating new ideas.")
+ except json.JSONDecodeError:
+ print("Error decoding existing ideas. Generating new ideas.")
- if 'gpt' in model:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=n_responses,
- stop=None,
- seed=0,
- )
- content = [r.message.content for r in response.choices]
- new_msg_history = [
- new_msg_history + [{"role": "assistant", "content": c}] for c in content
- ]
- elif model == "llama-3-1-405b-instruct":
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model="meta-llama/llama-3.1-405b-instruct",
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=n_responses,
- stop=None,
- )
- content = [r.message.content for r in response.choices]
- new_msg_history = [
- new_msg_history + [{"role": "assistant", "content": c}] for c in content
- ]
- else:
- content, new_msg_history = [], []
- for _ in range(n_responses):
- c, hist = get_response_from_llm(
- msg,
- client,
- model,
- system_message,
- print_debug=False,
- msg_history=None,
- temperature=temperature,
- )
- content.append(c)
- new_msg_history.append(hist)
+ idea_str_archive = []
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
+ seed_ideas = json.load(f)
+ for seed_idea in seed_ideas:
+ idea_str_archive.append(json.dumps(seed_idea))
- if print_debug:
- print()
- print("*" * 20 + " LLM START " + "*" * 20)
- for j, msg in enumerate(new_msg_history[0]):
- print(f'{j}, {msg["role"]}: {msg["content"]}')
- print(content)
- print("*" * 21 + " LLM END " + "*" * 21)
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+
+ idea_system_prompt = prompt["system"]
+
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
+
+ for _ in range(max_num_generations):
print()
+ print(f"Generating idea {_ + 1}/{max_num_generations}")
+ try:
+ prev_ideas_string = "\n\n".join(idea_str_archive)
+
+ msg_history = []
+ print(f"Iteration 1/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
+ task_description=prompt["task_description"],
+ code=code,
+ prev_ideas_string=prev_ideas_string,
+ num_reflections=num_reflections,
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ # Iteratively improve task.
+ if num_reflections > 1:
+ for j in range(num_reflections - 1):
+ print(f"Iteration {j + 2}/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ idea_reflection_prompt.format(
+ current_round=j + 2, num_reflections=num_reflections
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert (
+ json_output is not None
+ ), "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ if "I am done" in text:
+ print(f"Idea generation converged after {j + 2} iterations.")
+ break
+
+ idea_str_archive.append(json.dumps(json_output))
+ except Exception as e:
+ print(f"Failed to generate idea: {e}")
+ continue
+
+ ## SAVE IDEAS
+ ideas = []
+ for idea_str in idea_str_archive:
+ ideas.append(json.loads(idea_str))
- return content, new_msg_history
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
+ json.dump(ideas, f, indent=4)
+ return ideas
-@backoff.on_exception(backoff.expo, (openai.RateLimitError, openai.APITimeoutError))
-def get_response_from_llm(
- msg,
+
+# GENERATE IDEAS OPEN-ENDED
+def generate_next_idea(
+ base_dir,
client,
model,
- system_message,
- print_debug=False,
- msg_history=None,
- temperature=0.75,
+ prev_idea_archive=[],
+ num_reflections=5,
+ max_attempts=10,
):
- if msg_history is None:
- msg_history = []
+ idea_archive = prev_idea_archive
+ original_archive_size = len(idea_archive)
- if "claude" in model:
- new_msg_history = msg_history + [
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": msg,
- }
- ],
- }
- ]
- response = client.messages.create(
- model=model,
- max_tokens=MAX_NUM_TOKENS,
- temperature=temperature,
- system=system_message,
- messages=new_msg_history,
- )
- content = response.content[0].text
- new_msg_history = new_msg_history + [
- {
- "role": "assistant",
- "content": [
- {
- "type": "text",
- "text": content,
- }
- ],
- }
- ]
- elif 'gpt' in model:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=1,
- stop=None,
- seed=0,
- )
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
- elif "o1" in model or "o3" in model:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "user", "content": system_message},
- *new_msg_history,
- ],
- temperature=1,
- max_completion_tokens=MAX_NUM_TOKENS,
- n=1,
- seed=0,
- )
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
- elif model in ["meta-llama/llama-3.1-405b-instruct", "llama-3-1-405b-instruct"]:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model="meta-llama/llama-3.1-405b-instruct",
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=1,
- stop=None,
- )
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
- elif model in ["deepseek-chat", "deepseek-coder"]:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=1,
- stop=None,
- )
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
- elif model in ["deepseek-reasoner"]:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- n=1,
- stop=None,
- )
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
- elif "gemini" in model:
- new_msg_history = msg_history + [{"role": "user", "content": msg}]
- response = client.chat.completions.create(
- model=model,
- messages=[
- {"role": "system", "content": system_message},
- *new_msg_history,
- ],
- temperature=temperature,
- max_tokens=MAX_NUM_TOKENS,
- n=1,
+ print(f"Generating idea {original_archive_size + 1}")
+
+ if len(prev_idea_archive) == 0:
+ print(f"First iteration, taking seed ideas")
+ # seed the archive on the first run with pre-existing ideas
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
+ seed_ideas = json.load(f)
+ for seed_idea in seed_ideas[:1]:
+ idea_archive.append(seed_idea)
+ else:
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+ idea_system_prompt = prompt["system"]
+
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
+
+ for _ in range(max_attempts):
+ try:
+ idea_strings = []
+ for idea in idea_archive:
+ idea_strings.append(json.dumps(idea))
+ prev_ideas_string = "\n\n".join(idea_strings)
+
+ msg_history = []
+ print(f"Iteration 1/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
+ task_description=prompt["task_description"],
+ code=code,
+ prev_ideas_string=prev_ideas_string,
+ num_reflections=num_reflections,
+ )
+ + """
+Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
+This is on a standard 1-10 ML conference scale.
+Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
+""",
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ # Iteratively improve task.
+ if num_reflections > 1:
+ for j in range(num_reflections - 1):
+ print(f"Iteration {j + 2}/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ idea_reflection_prompt.format(
+ current_round=j + 2, num_reflections=num_reflections
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert (
+ json_output is not None
+ ), "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ if "I am done" in text:
+ print(
+ f"Idea generation converged after {j + 2} iterations."
+ )
+ break
+
+ idea_archive.append(json_output)
+ break
+ except Exception as e:
+ print(f"Failed to generate idea: {e}")
+ continue
+
+ ## SAVE IDEAS
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
+ json.dump(idea_archive, f, indent=4)
+
+ return idea_archive
+
+
+def on_backoff(details):
+ print(
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
+ f"calling function {details['target'].__name__} at {time.strftime('%X')}"
+ )
+
+
+@backoff.on_exception(
+ backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
+)
+def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]:
+ if not query:
+ return None
+ if engine == "semanticscholar":
+ rsp = requests.get(
+ "https://api.semanticscholar.org/graph/v1/paper/search",
+ headers={"X-API-KEY": S2_API_KEY} if S2_API_KEY else {},
+ params={
+ "query": query,
+ "limit": result_limit,
+ "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
+ },
)
- content = response.choices[0].message.content
- new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
+ print(f"Response Status Code: {rsp.status_code}")
+ print(
+ f"Response Content: {rsp.text[:500]}"
+ ) # Print the first 500 characters of the response content
+ rsp.raise_for_status()
+ results = rsp.json()
+ total = results["total"]
+ time.sleep(1.0)
+ if not total:
+ return None
+
+ papers = results["data"]
+ return papers
+ elif engine == "openalex":
+ import pyalex
+ from pyalex import Work, Works
+ mail = os.environ.get("OPENALEX_MAIL_ADDRESS", None)
+ if mail is None:
+ print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better access to OpenAlex API!")
+ else:
+ pyalex.config.email = mail
+
+ def extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> dict[str, str]:
+ # "Unknown" is returned when venue is unknown...
+ venue = "Unknown"
+ for i, location in enumerate(work["locations"]):
+ if location["source"] is not None:
+ venue = location["source"]["display_name"]
+ if venue != "":
+ break
+ title = work["title"]
+ abstract = work["abstract"]
+ if abstract is None:
+ abstract = ""
+ if len(abstract) > max_abstract_length:
+ # To avoid context length exceed error.
+ print(f"[WARNING] {title=}: {len(abstract)=} is too long! Use first {max_abstract_length} chars.")
+ abstract = abstract[:max_abstract_length]
+ authors_list = [author["author"]["display_name"] for author in work["authorships"]]
+ authors = " and ".join(authors_list) if len(authors_list) < 20 else f"{authors_list[0]} et al."
+ paper = dict(
+ title=title,
+ authors=authors,
+ venue=venue,
+ year=work["publication_year"],
+ abstract=abstract,
+ citationCount=work["cited_by_count"],
+ )
+ return paper
+
+ works: List[Dict] = Works().search(query).get(per_page=result_limit)
+ papers: List[Dict[str, str]] = [extract_info_from_work(work) for work in works]
+ return papers
else:
- raise ValueError(f"Model {model} not supported.")
+ raise NotImplementedError(f"{engine=} not supported!")
- if print_debug:
- print()
- print("*" * 20 + " LLM START " + "*" * 20)
- for j, msg in enumerate(new_msg_history):
- print(f'{j}, {msg["role"]}: {msg["content"]}')
- print(content)
- print("*" * 21 + " LLM END " + "*" * 21)
- print()
- return content, new_msg_history
+novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
+You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
+Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
+You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
+The top 10 results for any search query will be presented to you with the abstracts.
-def extract_json_between_markers(llm_output):
- # Regular expression pattern to find JSON content between ```json and ```
- json_pattern = r"```json(.*?)```"
- matches = re.findall(json_pattern, llm_output, re.DOTALL)
+You will be given {num_rounds} to decide on the paper, but you do not need to use them all.
+At any round, you may exit early and decide on the novelty of the idea.
+Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
+Decide a paper idea is not novel, if you have found a paper that significantly overlaps with your idea.
- if not matches:
- # Fallback: Try to find any JSON-like content in the output
- json_pattern = r"\{.*?\}"
- matches = re.findall(json_pattern, llm_output, re.DOTALL)
+{task_description}
+
+{code}
+
+"""
- for json_string in matches:
- json_string = json_string.strip()
- try:
- parsed_json = json.loads(json_string)
- return parsed_json
- except json.JSONDecodeError:
- # Attempt to fix common JSON issues
+novelty_prompt = '''Round {current_round}/{num_rounds}.
+You have this idea:
+
+"""
+{idea}
+"""
+
+The results of the last query are (empty on first round):
+"""
+{last_query_results}
+"""
+
+Respond in the following format:
+
+THOUGHT:
+
+
+RESPONSE:
+```json
+
+```
+
+In , first briefly reason over the idea and identify any query that could help you make your decision.
+If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
+
+In , respond in JSON format with ONLY the following field:
+- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
+
+A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
+This JSON will be automatically parsed, so ensure the format is precise.'''
+
+
+def check_idea_novelty(
+ ideas,
+ base_dir,
+ client,
+ model,
+ max_num_iterations=10,
+ engine="semanticscholar",
+):
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+ task_description = prompt["task_description"]
+
+ for idx, idea in enumerate(ideas):
+ if "novel" in idea:
+ print(f"Skipping idea {idx}, already checked.")
+ continue
+
+ print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
+
+ novel = False
+ msg_history = []
+ papers_str = ""
+
+ for j in range(max_num_iterations):
try:
- # Remove invalid control characters
- json_string_clean = re.sub(r"[\x00-\x1F\x7F]", "", json_string)
- parsed_json = json.loads(json_string_clean)
- return parsed_json
- except json.JSONDecodeError:
- continue # Try next match
-
- return None # No valid JSON found
-
-
-def create_client(model):
- if model.startswith("claude-"):
- print(f"Using Anthropic API with model {model}.")
- return anthropic.Anthropic(), model
- elif model.startswith("bedrock") and "claude" in model:
- client_model = model.split("/")[-1]
- print(f"Using Amazon Bedrock with model {client_model}.")
- return anthropic.AnthropicBedrock(), client_model
- elif model.startswith("vertex_ai") and "claude" in model:
- client_model = model.split("/")[-1]
- print(f"Using Vertex AI with model {client_model}.")
- return anthropic.AnthropicVertex(), client_model
- elif 'gpt' in model or "o1" in model or "o3" in model:
- print(f"Using OpenAI API with model {model}.")
- return openai.OpenAI(), model
- elif model in ["deepseek-chat", "deepseek-reasoner", "deepseek-coder"]:
- print(f"Using OpenAI API with {model}.")
- return openai.OpenAI(
- api_key=os.environ["DEEPSEEK_API_KEY"],
- base_url="https://api.deepseek.com"
- ), model
- elif model == "llama3.1-405b":
- print(f"Using OpenAI API with {model}.")
- return openai.OpenAI(
- api_key=os.environ["OPENROUTER_API_KEY"],
- base_url="https://openrouter.ai/api/v1"
- ), "meta-llama/llama-3.1-405b-instruct"
- elif "gemini" in model:
- print(f"Using OpenAI API with {model}.")
- return openai.OpenAI(
- api_key=os.environ["GEMINI_API_KEY"],
- base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
- ), model
- else:
- raise ValueError(f"Model {model} not supported.")
+ text, msg_history = get_response_from_llm(
+ novelty_prompt.format(
+ current_round=j + 1,
+ num_rounds=max_num_iterations,
+ idea=idea,
+ last_query_results=papers_str,
+ ),
+ client=client,
+ model=model,
+ system_message=novelty_system_msg.format(
+ num_rounds=max_num_iterations,
+ task_description=task_description,
+ code=code,
+ ),
+ msg_history=msg_history,
+ )
+ if "decision made: novel" in text.lower():
+ print("Decision made: novel after round", j)
+ novel = True
+ break
+ if "decision made: not novel" in text.lower():
+ print("Decision made: not novel after round", j)
+ break
+
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+
+ ## SEARCH FOR PAPERS
+ query = json_output["Query"]
+ papers = search_for_papers(query, result_limit=10, engine=engine)
+ if papers is None:
+ papers_str = "No papers found."
+
+ paper_strings = []
+ for i, paper in enumerate(papers):
+ paper_strings.append(
+ """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
+ i=i,
+ title=paper["title"],
+ authors=paper["authors"],
+ venue=paper["venue"],
+ year=paper["year"],
+ cites=paper["citationCount"],
+ abstract=paper["abstract"],
+ )
+ )
+ papers_str = "\n\n".join(paper_strings)
+
+ except Exception as e:
+ print(f"Error: {e}")
+ continue
+
+ idea["novel"] = novel
+
+ # Save results to JSON file
+ results_file = osp.join(base_dir, "ideas.json")
+ with open(results_file, "w") as f:
+ json.dump(ideas, f, indent=4)
+
+ return ideas
+
+
+if __name__ == "__main__":
+ MAX_NUM_GENERATIONS = 32
+ NUM_REFLECTIONS = 5
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
+ # add type of experiment (nanoGPT, Boston, etc.)
+ parser.add_argument(
+ "--experiment",
+ type=str,
+ default="nanoGPT",
+ help="Experiment to run AI Scientist on.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="gpt-4o-2024-05-13",
+ choices=AVAILABLE_LLMS,
+ help="Model to use for AI Scientist.",
+ )
+ parser.add_argument(
+ "--skip-idea-generation",
+ action="store_true",
+ help="Skip idea generation and use existing ideas.",
+ )
+ parser.add_argument(
+ "--check-novelty",
+ action="store_true",
+ help="Check novelty of ideas.",
+ )
+ args = parser.parse_args()
+
+ # Create client
+ client, client_model = create_client(args.model)
+
+ base_dir = osp.join("templates", args.experiment)
+ results_dir = osp.join("results", args.experiment)
+
+ overview = generate_overview_paragraph(base_dir, client, client_model)
+
+ ideas = generate_ideas(
+ base_dir,
+ client=client,
+ model=client_model,
+ skip_generation=args.skip_idea_generation,
+ max_num_generations=MAX_NUM_GENERATIONS,
+ num_reflections=NUM_REFLECTIONS,
+ )
+ if args.check_novelty:
+ ideas = check_idea_novelty(
+ ideas,
+ base_dir=base_dir,
+ client=client,
+ model=client_model,
+ )
diff --git a/ai_scientist/perform_experiments.py b/ai_scientist/perform_experiments.py
index bb8c248a..20d8c234 100644
--- a/ai_scientist/perform_experiments.py
+++ b/ai_scientist/perform_experiments.py
@@ -1,166 +1,652 @@
import json
+import os
import os.path as osp
-import shutil
-import subprocess
-import sys
-from subprocess import TimeoutExpired
+import time
+from typing import List, Dict, Union
-MAX_ITERS = 4
-MAX_RUNS = 5
-MAX_STDERR_OUTPUT = 1500
+import backoff
+import requests
-coder_prompt = """Your goal is to implement the following idea: {title}.
-The proposed experiment is as follows: {idea}.
-You are given a total of up to {max_runs} runs to complete the necessary experiments. You do not need to use all {max_runs}.
+from ai_scientist.llm import get_response_from_llm, extract_json_between_markers, create_client, AVAILABLE_LLMS
-First, plan the list of experiments you would like to run. For example, if you are sweeping over a specific hyperparameter, plan each value you would like to test for each run.
+S2_API_KEY = os.getenv("S2_API_KEY")
-Note that we already provide the vanilla baseline results, so you do not need to re-run it.
+idea_first_prompt = """{task_description}
+
+{code}
+
-For reference, the baseline results are as follows:
+Here are the ideas that you have already generated:
-{baseline_results}
+'''
+{prev_ideas_string}
+'''
-After you complete each change, we will run the command `python experiment.py --out_dir=run_i' where i is the run number and evaluate the results.
-YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
-You can then implement the next thing on your list."""
+Come up with the next impactful and creative idea for research experiments and directions you can feasibly investigate with the code provided.
+Note that you will not have access to any additional resources or datasets.
+Make sure any idea is not overfit the specific training dataset or model, and has wider significance.
+Respond in the following format:
-# RUN EXPERIMENT
-def run_experiment(folder_name, run_num, timeout=7200):
- cwd = osp.abspath(folder_name)
- # COPY CODE SO WE CAN SEE IT.
- shutil.copy(
- osp.join(folder_name, "experiment.py"),
- osp.join(folder_name, f"run_{run_num}.py"),
+THOUGHT:
+
+
+NEW IDEA JSON:
+```json
+
+```
+
+In , first briefly discuss your intuitions and motivations for the idea. Detail your high-level plan, necessary design choices and ideal outcomes of the experiments. Justify how the idea is different from the existing ones.
+
+In , provide the new idea in JSON format with the following fields:
+- "Name": A shortened descriptor of the idea. Lowercase, no spaces, underscores allowed.
+- "Title": A title for the idea, will be used for the report writing.
+- "Experiment": An outline of the implementation. E.g. which functions need to be added or modified, how results will be obtained, ...
+- "Interestingness": A rating from 1 to 10 (lowest to highest).
+- "Feasibility": A rating from 1 to 10 (lowest to highest).
+- "Novelty": A rating from 1 to 10 (lowest to highest).
+
+Be cautious and realistic on your ratings.
+This JSON will be automatically parsed, so ensure the format is precise.
+You will have {num_reflections} rounds to iterate on the idea, but do not need to use them all.
+"""
+
+idea_reflection_prompt = """Round {current_round}/{num_reflections}.
+In your thoughts, first carefully consider the quality, novelty, and feasibility of the idea you just created.
+Include any other factors that you think are important in evaluating the idea.
+Ensure the idea is clear and concise, and the JSON is the correct format.
+Do not make things overly complicated.
+In the next attempt, try and refine and improve your idea.
+Stick to the spirit of the original idea unless there are glaring issues.
+
+Respond in the same format as before:
+THOUGHT:
+
+
+NEW IDEA JSON:
+```json
+
+```
+
+If there is nothing to improve, simply repeat the previous JSON EXACTLY after the thought and include "I am done" at the end of the thoughts but before the JSON.
+ONLY INCLUDE "I am done" IF YOU ARE MAKING NO MORE CHANGES."""
+
+# generate overview paragraph
+def generate_overview_paragraph(base_dir, client, model, engine="semanticscholar"):
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+
+ task_description = prompt["task_description"]
+
+ # use LLM to generate search queries
+ query_generation_prompt = f"""You are a machine learning researcher preparing a literature review.
+Based on the following research task and code snippet, generate 3 to 5 search queries that would retrieve relevant papers from an academic search engine like Semantic Scholar.
+
+TASK DESCRIPTION:
+{task_description}
+
+CODE:
+
+{code}
+
+
+Return only the search queries in JSON list format:
+```json
+[
+ "query 1",
+ "query 2",
+ ...
+]
+```"""
+
+ query_text, _ = get_response_from_llm(
+ query_generation_prompt,
+ client=client,
+ model=model,
+ system_message="You generate search queries for academic literature.",
+ msg_history=[]
)
- # LAUNCH COMMAND
- command = [
- "python",
- "experiment.py",
- f"--out_dir=run_{run_num}",
- ]
try:
- result = subprocess.run(
- command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
+ query_list = extract_json_between_markers(query_text)
+ assert isinstance(query_list, list) and len(query_list) > 0
+ except Exception as e:
+ print("Failed to generate or parse search queries:", e)
+ return "Failed to generate search queries."
+
+ print("Generated queries:", query_list)
+
+ # use the generated queries to retrieve papers
+ all_paper_summaries = []
+ for query in query_list:
+ try:
+ papers = search_for_papers(query, result_limit=5, engine=engine)
+ if not papers:
+ continue
+ for p in papers:
+ all_paper_summaries.append(f"Title: {p['title']}\nAbstract: {p['abstract']}\n")
+ except Exception as e:
+ print(f"Query failed: {query}, error: {e}")
+ continue
+
+ if len(all_paper_summaries) == 0:
+ print("No relevant papers found across all queries.")
+ return "No relevant papers found to summarize."
+
+ # summarize all paper abstracts
+ paper_str = "\n\n".join(all_paper_summaries)
+
+ overview_prompt = f"""You are a highly skilled AI researcher. The following papers were retrieved related to the task:
+
+{paper_str}
+
+Summarize the current research directions and trends based on these papers. Write a concise and insightful paragraph to help guide future experiment ideas. Only output the paragraph without any commentary."""
+
+ overview_text, _ = get_response_from_llm(
+ overview_prompt,
+ client=client,
+ model=model,
+ system_message="You summarize academic papers into insightful overviews.",
+ msg_history=[]
+ )
+
+ with open(osp.join(base_dir, "overview_prompt.txt"), "w") as f:
+ f.write(overview_prompt.strip())
+ with open(osp.join(base_dir, "overview.txt"), "w") as f:
+ f.write(overview_text.strip())
+
+ return overview_text.strip()
+
+# GENERATE IDEAS
+def generate_ideas(
+ base_dir,
+ client,
+ model,
+ skip_generation=False,
+ max_num_generations=20,
+ num_reflections=5,
+):
+ if skip_generation:
+ # Load existing ideas from file
+ try:
+ with open(osp.join(base_dir, "ideas.json"), "r") as f:
+ ideas = json.load(f)
+ print("Loaded existing ideas:")
+ for idea in ideas:
+ print(idea)
+ return ideas
+ except FileNotFoundError:
+ print("No existing ideas found. Generating new ideas.")
+ except json.JSONDecodeError:
+ print("Error decoding existing ideas. Generating new ideas.")
+
+ idea_str_archive = []
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
+ seed_ideas = json.load(f)
+ for seed_idea in seed_ideas:
+ idea_str_archive.append(json.dumps(seed_idea))
+
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+
+ idea_system_prompt = prompt["system"]
+
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
+
+ for _ in range(max_num_generations):
+ print()
+ print(f"Generating idea {_ + 1}/{max_num_generations}")
+ try:
+ prev_ideas_string = "\n\n".join(idea_str_archive)
+
+ msg_history = []
+ print(f"Iteration 1/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
+ task_description=prompt["task_description"],
+ code=code,
+ prev_ideas_string=prev_ideas_string,
+ num_reflections=num_reflections,
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ # Iteratively improve task.
+ if num_reflections > 1:
+ for j in range(num_reflections - 1):
+ print(f"Iteration {j + 2}/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ idea_reflection_prompt.format(
+ current_round=j + 2, num_reflections=num_reflections
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert (
+ json_output is not None
+ ), "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ if "I am done" in text:
+ print(f"Idea generation converged after {j + 2} iterations.")
+ break
+
+ idea_str_archive.append(json.dumps(json_output))
+ except Exception as e:
+ print(f"Failed to generate idea: {e}")
+ continue
+
+ ## SAVE IDEAS
+ ideas = []
+ for idea_str in idea_str_archive:
+ ideas.append(json.loads(idea_str))
+
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
+ json.dump(ideas, f, indent=4)
+
+ return ideas
+
+
+# GENERATE IDEAS OPEN-ENDED
+def generate_next_idea(
+ base_dir,
+ client,
+ model,
+ prev_idea_archive=[],
+ num_reflections=5,
+ max_attempts=10,
+):
+ idea_archive = prev_idea_archive
+ original_archive_size = len(idea_archive)
+
+ print(f"Generating idea {original_archive_size + 1}")
+
+ if len(prev_idea_archive) == 0:
+ print(f"First iteration, taking seed ideas")
+ # seed the archive on the first run with pre-existing ideas
+ with open(osp.join(base_dir, "seed_ideas.json"), "r") as f:
+ seed_ideas = json.load(f)
+ for seed_idea in seed_ideas[:1]:
+ idea_archive.append(seed_idea)
+ else:
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+ idea_system_prompt = prompt["system"]
+
+ # add overview as a condition
+ with open(osp.join(base_dir, "overview.txt"), "r") as f:
+ overview_paragraph = f.read()
+
+ # And include in idea_first_prompt:
+ idea_first_prompt_with_overview = "Current research overview:\n" + overview_paragraph + "\n\n" + idea_first_prompt
+
+ for _ in range(max_attempts):
+ try:
+ idea_strings = []
+ for idea in idea_archive:
+ idea_strings.append(json.dumps(idea))
+ prev_ideas_string = "\n\n".join(idea_strings)
+
+ msg_history = []
+ print(f"Iteration 1/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ # idea_first_prompt.format(
+ idea_first_prompt_with_overview.format(
+ task_description=prompt["task_description"],
+ code=code,
+ prev_ideas_string=prev_ideas_string,
+ num_reflections=num_reflections,
+ )
+ + """
+Completed ideas have an additional "Score" field which indicates the assessment by an expert ML reviewer.
+This is on a standard 1-10 ML conference scale.
+Scores of 0 indicate the idea failed either during experimentation, writeup or reviewing.
+""",
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ # Iteratively improve task.
+ if num_reflections > 1:
+ for j in range(num_reflections - 1):
+ print(f"Iteration {j + 2}/{num_reflections}")
+ text, msg_history = get_response_from_llm(
+ idea_reflection_prompt.format(
+ current_round=j + 2, num_reflections=num_reflections
+ ),
+ client=client,
+ model=model,
+ system_message=idea_system_prompt,
+ msg_history=msg_history,
+ )
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert (
+ json_output is not None
+ ), "Failed to extract JSON from LLM output"
+ print(json_output)
+
+ if "I am done" in text:
+ print(
+ f"Idea generation converged after {j + 2} iterations."
+ )
+ break
+
+ idea_archive.append(json_output)
+ break
+ except Exception as e:
+ print(f"Failed to generate idea: {e}")
+ continue
+
+ ## SAVE IDEAS
+ with open(osp.join(base_dir, "ideas.json"), "w") as f:
+ json.dump(idea_archive, f, indent=4)
+
+ return idea_archive
+
+
+def on_backoff(details):
+ print(
+ f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
+ f"calling function {details['target'].__name__} at {time.strftime('%X')}"
+ )
+
+
+@backoff.on_exception(
+ backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff
+)
+def search_for_papers(query, result_limit=10, engine="semanticscholar") -> Union[None, List[Dict]]:
+ if not query:
+ return None
+ if engine == "semanticscholar":
+ rsp = requests.get(
+ "https://api.semanticscholar.org/graph/v1/paper/search",
+ headers={"X-API-KEY": S2_API_KEY} if S2_API_KEY else {},
+ params={
+ "query": query,
+ "limit": result_limit,
+ "fields": "title,authors,venue,year,abstract,citationStyles,citationCount",
+ },
)
+ print(f"Response Status Code: {rsp.status_code}")
+ print(
+ f"Response Content: {rsp.text[:500]}"
+ ) # Print the first 500 characters of the response content
+ rsp.raise_for_status()
+ results = rsp.json()
+ total = results["total"]
+ time.sleep(1.0)
+ if not total:
+ return None
- if result.stderr:
- print(result.stderr, file=sys.stderr)
-
- if result.returncode != 0:
- print(f"Run {run_num} failed with return code {result.returncode}")
- if osp.exists(osp.join(cwd, f"run_{run_num}")):
- shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
- print(f"Run failed with the following error {result.stderr}")
- stderr_output = result.stderr
- if len(stderr_output) > MAX_STDERR_OUTPUT:
- stderr_output = "..." + stderr_output[-MAX_STDERR_OUTPUT:]
- next_prompt = f"Run failed with the following error {stderr_output}"
+ papers = results["data"]
+ return papers
+ elif engine == "openalex":
+ import pyalex
+ from pyalex import Work, Works
+ mail = os.environ.get("OPENALEX_MAIL_ADDRESS", None)
+ if mail is None:
+ print("[WARNING] Please set OPENALEX_MAIL_ADDRESS for better access to OpenAlex API!")
else:
- with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
- results = json.load(f)
- results = {k: v["means"] for k, v in results.items()}
-
- next_prompt = f"""Run {run_num} completed. Here are the results:
-{results}
-
-Decide if you need to re-plan your experiments given the result (you often will not need to).
-
-Someone else will be using `notes.txt` to perform a writeup on this in the future.
-Please include *all* relevant information for the writeup on Run {run_num}, including an experiment description and the run number. Be as verbose as necessary.
-
-Then, implement the next thing on your list.
-We will then run the command `python experiment.py --out_dir=run_{run_num + 1}'.
-YOUR PROPOSED CHANGE MUST USE THIS COMMAND FORMAT, DO NOT ADD ADDITIONAL COMMAND LINE ARGS.
-If you are finished with experiments, respond with 'ALL_COMPLETED'."""
- return result.returncode, next_prompt
- except TimeoutExpired:
- print(f"Run {run_num} timed out after {timeout} seconds")
- if osp.exists(osp.join(cwd, f"run_{run_num}")):
- shutil.rmtree(osp.join(cwd, f"run_{run_num}"))
- next_prompt = f"Run timed out after {timeout} seconds"
- return 1, next_prompt
-
-
-# RUN PLOTTING
-def run_plotting(folder_name, timeout=600):
- cwd = osp.abspath(folder_name)
- # LAUNCH COMMAND
- command = [
- "python",
- "plot.py",
- ]
- try:
- result = subprocess.run(
- command, cwd=cwd, stderr=subprocess.PIPE, text=True, timeout=timeout
- )
+ pyalex.config.email = mail
- if result.stderr:
- print(result.stderr, file=sys.stderr)
+ def extract_info_from_work(work: Work, max_abstract_length: int = 1000) -> dict[str, str]:
+ # "Unknown" is returned when venue is unknown...
+ venue = "Unknown"
+ for i, location in enumerate(work["locations"]):
+ if location["source"] is not None:
+ venue = location["source"]["display_name"]
+ if venue != "":
+ break
+ title = work["title"]
+ abstract = work["abstract"]
+ if abstract is None:
+ abstract = ""
+ if len(abstract) > max_abstract_length:
+ # To avoid context length exceed error.
+ print(f"[WARNING] {title=}: {len(abstract)=} is too long! Use first {max_abstract_length} chars.")
+ abstract = abstract[:max_abstract_length]
+ authors_list = [author["author"]["display_name"] for author in work["authorships"]]
+ authors = " and ".join(authors_list) if len(authors_list) < 20 else f"{authors_list[0]} et al."
+ paper = dict(
+ title=title,
+ authors=authors,
+ venue=venue,
+ year=work["publication_year"],
+ abstract=abstract,
+ citationCount=work["cited_by_count"],
+ )
+ return paper
- if result.returncode != 0:
- print(f"Plotting failed with return code {result.returncode}")
- next_prompt = f"Plotting failed with the following error {result.stderr}"
- else:
- next_prompt = ""
- return result.returncode, next_prompt
- except TimeoutExpired:
- print(f"Plotting timed out after {timeout} seconds")
- next_prompt = f"Plotting timed out after {timeout} seconds"
- return 1, next_prompt
-
-
-# PERFORM EXPERIMENTS
-def perform_experiments(idea, folder_name, coder, baseline_results) -> bool:
- ## RUN EXPERIMENT
- current_iter = 0
- run = 1
- next_prompt = coder_prompt.format(
- title=idea["Title"],
- idea=idea["Experiment"],
- max_runs=MAX_RUNS,
- baseline_results=baseline_results,
- )
- while run < MAX_RUNS + 1:
- if current_iter >= MAX_ITERS:
- print("Max iterations reached")
- break
- coder_out = coder.run(next_prompt)
- print(coder_out)
- if "ALL_COMPLETED" in coder_out:
- break
- return_code, next_prompt = run_experiment(folder_name, run)
- if return_code == 0:
- run += 1
- current_iter = 0
- current_iter += 1
- if current_iter >= MAX_ITERS:
- print("Not all experiments completed.")
- return False
-
- current_iter = 0
- next_prompt = """
-Great job! Please modify `plot.py` to generate the most relevant plots for the final writeup.
-
-In particular, be sure to fill in the "labels" dictionary with the correct names for each run that you want to plot.
-
-Only the runs in the `labels` dictionary will be plotted, so make sure to include all relevant runs.
-
-We will be running the command `python plot.py` to generate the plots.
+ works: List[Dict] = Works().search(query).get(per_page=result_limit)
+ papers: List[Dict[str, str]] = [extract_info_from_work(work) for work in works]
+ return papers
+ else:
+ raise NotImplementedError(f"{engine=} not supported!")
+
+
+
+novelty_system_msg = """You are an ambitious AI PhD student who is looking to publish a paper that will contribute significantly to the field.
+You have an idea and you want to check if it is novel or not. I.e., not overlapping significantly with existing literature or already well explored.
+Be a harsh critic for novelty, ensure there is a sufficient contribution in the idea for a new conference or workshop paper.
+You will be given access to the Semantic Scholar API, which you may use to survey the literature and find relevant papers to help you make your decision.
+The top 10 results for any search query will be presented to you with the abstracts.
+
+You will be given {num_rounds} to decide on the paper, but you do not need to use them all.
+At any round, you may exit early and decide on the novelty of the idea.
+Decide a paper idea is novel if after sufficient searching, you have not found a paper that significantly overlaps with your idea.
+Decide a paper idea is not novel, if you have found a paper that significantly overlaps with your idea.
+
+{task_description}
+
+{code}
+
"""
- while True:
- _ = coder.run(next_prompt)
- return_code, next_prompt = run_plotting(folder_name)
- current_iter += 1
- if return_code == 0 or current_iter >= MAX_ITERS:
- break
- next_prompt = """
-Please modify `notes.txt` with a description of what each plot shows along with the filename of the figure. Please do so in-depth.
-
-Somebody else will be using `notes.txt` to write a report on this in the future.
+
+novelty_prompt = '''Round {current_round}/{num_rounds}.
+You have this idea:
+
+"""
+{idea}
+"""
+
+The results of the last query are (empty on first round):
"""
- coder.run(next_prompt)
+{last_query_results}
+"""
+
+Respond in the following format:
+
+THOUGHT:
+
+
+RESPONSE:
+```json
+
+```
+
+In , first briefly reason over the idea and identify any query that could help you make your decision.
+If you have made your decision, add "Decision made: novel." or "Decision made: not novel." to your thoughts.
+
+In , respond in JSON format with ONLY the following field:
+- "Query": An optional search query to search the literature (e.g. attention is all you need). You must make a query if you have not decided this round.
- return True
+A query will work best if you are able to recall the exact name of the paper you are looking for, or the authors.
+This JSON will be automatically parsed, so ensure the format is precise.'''
+
+
+def check_idea_novelty(
+ ideas,
+ base_dir,
+ client,
+ model,
+ max_num_iterations=10,
+ engine="semanticscholar",
+):
+ with open(osp.join(base_dir, "experiment.py"), "r") as f:
+ code = f.read()
+ with open(osp.join(base_dir, "prompt.json"), "r") as f:
+ prompt = json.load(f)
+ task_description = prompt["task_description"]
+
+ for idx, idea in enumerate(ideas):
+ if "novel" in idea:
+ print(f"Skipping idea {idx}, already checked.")
+ continue
+
+ print(f"\nChecking novelty of idea {idx}: {idea['Name']}")
+
+ novel = False
+ msg_history = []
+ papers_str = ""
+
+ for j in range(max_num_iterations):
+ try:
+ text, msg_history = get_response_from_llm(
+ novelty_prompt.format(
+ current_round=j + 1,
+ num_rounds=max_num_iterations,
+ idea=idea,
+ last_query_results=papers_str,
+ ),
+ client=client,
+ model=model,
+ system_message=novelty_system_msg.format(
+ num_rounds=max_num_iterations,
+ task_description=task_description,
+ code=code,
+ ),
+ msg_history=msg_history,
+ )
+ if "decision made: novel" in text.lower():
+ print("Decision made: novel after round", j)
+ novel = True
+ break
+ if "decision made: not novel" in text.lower():
+ print("Decision made: not novel after round", j)
+ break
+
+ ## PARSE OUTPUT
+ json_output = extract_json_between_markers(text)
+ assert json_output is not None, "Failed to extract JSON from LLM output"
+
+ ## SEARCH FOR PAPERS
+ query = json_output["Query"]
+ papers = search_for_papers(query, result_limit=10, engine=engine)
+ if papers is None:
+ papers_str = "No papers found."
+
+ paper_strings = []
+ for i, paper in enumerate(papers):
+ paper_strings.append(
+ """{i}: {title}. {authors}. {venue}, {year}.\nNumber of citations: {cites}\nAbstract: {abstract}""".format(
+ i=i,
+ title=paper["title"],
+ authors=paper["authors"],
+ venue=paper["venue"],
+ year=paper["year"],
+ cites=paper["citationCount"],
+ abstract=paper["abstract"],
+ )
+ )
+ papers_str = "\n\n".join(paper_strings)
+
+ except Exception as e:
+ print(f"Error: {e}")
+ continue
+
+ idea["novel"] = novel
+
+ # Save results to JSON file
+ results_file = osp.join(base_dir, "ideas.json")
+ with open(results_file, "w") as f:
+ json.dump(ideas, f, indent=4)
+
+ return ideas
+
+
+if __name__ == "__main__":
+ MAX_NUM_GENERATIONS = 32
+ NUM_REFLECTIONS = 5
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Generate AI scientist ideas")
+ # add type of experiment (nanoGPT, Boston, etc.)
+ parser.add_argument(
+ "--experiment",
+ type=str,
+ default="nanoGPT",
+ help="Experiment to run AI Scientist on.",
+ )
+ parser.add_argument(
+ "--model",
+ type=str,
+ default="gpt-4o-2024-05-13",
+ choices=AVAILABLE_LLMS,
+ help="Model to use for AI Scientist.",
+ )
+ parser.add_argument(
+ "--skip-idea-generation",
+ action="store_true",
+ help="Skip idea generation and use existing ideas.",
+ )
+ parser.add_argument(
+ "--check-novelty",
+ action="store_true",
+ help="Check novelty of ideas.",
+ )
+ args = parser.parse_args()
+
+ # Create client
+ client, client_model = create_client(args.model)
+
+ base_dir = osp.join("templates", args.experiment)
+ results_dir = osp.join("results", args.experiment)
+
+ overview = generate_overview_paragraph(base_dir, client, client_model)
+
+ ideas = generate_ideas(
+ base_dir,
+ client=client,
+ model=client_model,
+ skip_generation=args.skip_idea_generation,
+ max_num_generations=MAX_NUM_GENERATIONS,
+ num_reflections=NUM_REFLECTIONS,
+ )
+ if args.check_novelty:
+ ideas = check_idea_novelty(
+ ideas,
+ base_dir=base_dir,
+ client=client,
+ model=client_model,
+ )
diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py
index 8fe07cb7..475b4a89 100644
--- a/ai_scientist/perform_writeup.py
+++ b/ai_scientist/perform_writeup.py
@@ -17,9 +17,49 @@ def generate_latex(coder, folder_name, pdf_file, timeout=30, num_error_correctio
cwd = osp.join(folder, "latex") # Fixed potential issue with path
writeup_file = osp.join(cwd, "template.tex")
- # Check all references are valid and in the references.bib file
with open(writeup_file, "r") as f:
tex_text = f.read()
+
+ # re-format tex file
+ if "\\begin{abstract}" in tex_text and "\\begin{document}" not in tex_text:
+ preamble = (
+ "\\documentclass{article}\n"
+ "\\usepackage{graphicx}\n"
+ "\\usepackage{natbib}\n"
+ "\\usepackage{amsmath, amssymb}\n"
+ "\\usepackage{geometry}\n"
+ )
+
+ if "\\maketitle" not in tex_text:
+ tex_text = re.sub(
+ r"(\\title\{.*?\})",
+ r"\1\n\\date{}\n\\maketitle",
+ tex_text,
+ flags=re.DOTALL
+ )
+
+ tex_text = (
+ preamble
+ + "\n\\begin{document}\n"
+ + tex_text
+ + "\\bibliographystyle{plainnat}\n"
+ + "\\bibliography{references}\n"
+ + "\n\\end{document}\n"
+ )
+
+ with open(writeup_file, "w") as f:
+ f.write(tex_text)
+
+ # copy plot image files
+ png_files = [f for f in os.listdir(folder) if f.endswith(".png")]
+ for file in png_files:
+ src = os.path.join(folder, file)
+ dst = os.path.join(cwd, file)
+ if not os.path.exists(dst):
+ shutil.copy(src, dst)
+ print(f"Copied image: {file} latex/")
+
+ # Check all references are valid and in the references.bib file
cites = re.findall(r"\\cite[a-z]*{([^}]*)}", tex_text)
references_bib = re.search(
r"\\begin{filecontents}{references.bib}(.*?)\\end{filecontents}",
@@ -419,6 +459,8 @@ def perform_writeup(
.replace(r"{{", "{")
.replace(r"}}", "}")
)
+
+ # each section
for section in [
"Introduction",
"Background",
@@ -462,23 +504,64 @@ def perform_writeup(
coder_out = coder.run(section_prompt)
# Fill paper with cites.
+ print("************ CITATION CHECKING START ... ************")
for _ in range(num_cite_rounds):
- with open(osp.join(folder_name, "latex", "template.tex"), "r") as f:
+ writeup_file = osp.join(folder_name, "latex", "template.tex")
+ with open(writeup_file, "r") as f:
draft = f.read()
+
prompt, done = get_citation_aider_prompt(
cite_client, cite_model, draft, _, num_cite_rounds, engine=engine
)
+
if done:
break
if prompt is not None:
+ # # extract bibtex string
+ # bibtex_string = prompt.split('"""')[1]
+ # # insert this into draft before the "\end{filecontents}" line
+ # search_str = r"\end{filecontents}"
+ # draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
+ # with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
+ # f.write(draft)
+
# extract bibtex string
- bibtex_string = prompt.split('"""')[1]
+ try:
+ bibtex_string = prompt.split('"""')[1].strip()
+ print(f"******** bibtex_string:\n{bibtex_string}\n **********")
+ except IndexError:
+ print("Failed to extract bibtex string from prompt.")
+ continue
+
# insert this into draft before the "\end{filecontents}" line
- search_str = r"\end{filecontents}"
- draft = draft.replace(search_str, f"{bibtex_string}{search_str}")
- with open(osp.join(folder_name, "latex", "template.tex"), "w") as f:
+ with open(writeup_file, "r") as f:
+ draft = f.read()
+
+ if r"\end{filecontents}" in draft:
+ draft = draft.replace(r"\end{filecontents}", f"{bibtex_string}\n\\end{{filecontents}}")
+ print("Inserted bibtex into existing references.bib block.")
+ else:
+ print("No \\end{filecontents} found. Creating new references.bib block.")
+
+ bib_block = f"""
+\\begin{{filecontents}}{{references.bib}}
+{bibtex_string}
+\\end{{filecontents}}
+""".lstrip()
+
+ if "\\documentclass" in draft:
+ draft = draft.replace("\\documentclass", bib_block + "\n\\documentclass", 1)
+ else:
+ draft = bib_block + draft # fallback
+
+ print("Inserted new references.bib block.")
+
+ with open(writeup_file, "w") as f:
f.write(draft)
+
+ # apply citation edit to the document
coder_out = coder.run(prompt)
+ print("************ CITATION CHECKING END !!! ************")
coder_out = coder.run(
refinement_prompt.format(section="Related Work")
diff --git a/launch_scientist.py b/launch_scientist.py
index 2fe7a49c..4f86ba5f 100644
--- a/launch_scientist.py
+++ b/launch_scientist.py
@@ -13,7 +13,7 @@
from aider.models import Model
from datetime import datetime
-from ai_scientist.generate_ideas import generate_ideas, check_idea_novelty
+from ai_scientist.generate_ideas import generate_overview_paragraph, generate_ideas, check_idea_novelty
from ai_scientist.llm import create_client, AVAILABLE_LLMS
from ai_scientist.perform_experiments import perform_experiments
from ai_scientist.perform_review import perform_review, load_paper, perform_improvement
@@ -171,9 +171,11 @@ def do_idea(
shutil.copytree(base_dir, destination_dir, dirs_exist_ok=True)
with open(osp.join(base_dir, "run_0", "final_info.json"), "r") as f:
baseline_results = json.load(f)
+
# Check if baseline_results is a dictionary before extracting means
if isinstance(baseline_results, dict):
baseline_results = {k: v["means"] for k, v in baseline_results.items()}
+
exp_file = osp.join(folder_name, "experiment.py")
vis_file = osp.join(folder_name, "plot.py")
notes = osp.join(folder_name, "notes.txt")
@@ -193,6 +195,7 @@ def do_idea(
try:
print_time()
print(f"*Starting idea: {idea_name}*")
+
## PERFORM EXPERIMENTS
fnames = [exp_file, vis_file, notes]
io = InputOutput(
@@ -340,6 +343,14 @@ def do_idea(
base_dir = osp.join("templates", args.experiment)
results_dir = osp.join("results", args.experiment)
+
+ # Idea generation
+ print("====================== IDEA EXPLORATION START... ======================")
+ overview = generate_overview_paragraph(base_dir, client, client_model)
+
+ print("====================== IDEA EXPLORATION END!!! ======================")
+
+ print("====================== IDEA GENERATION START... ======================")
ideas = generate_ideas(
base_dir,
client=client,
@@ -348,6 +359,10 @@ def do_idea(
max_num_generations=args.num_ideas,
num_reflections=NUM_REFLECTIONS,
)
+ print("====================== IDEA GENERATION END!!! ======================")
+ print(ideas)
+
+ print("====================== IDEA CHECKING START... ======================")
if not args.skip_novelty_check:
ideas = check_idea_novelty(
ideas,
@@ -362,7 +377,11 @@ def do_idea(
novel_ideas = [idea for idea in ideas if idea["novel"]]
# novel_ideas = list(reversed(novel_ideas))
+ print("====================== IDEA CHECKING END!!! ======================")
+ print(novel_ideas)
+ # Experiment
+ print("====================== EXPERIMENT START... ======================")
if args.parallel > 0:
print(f"Running {args.parallel} parallel processes")
queue = multiprocessing.Queue()
@@ -418,3 +437,4 @@ def do_idea(
import traceback
print(traceback.format_exc())
print("All ideas evaluated.")
+ print("====================== EXPERIMENT END !!! ======================")
diff --git a/requirements.txt b/requirements.txt
index 8971848d..fcc8f1b3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
# LLM APIs
anthropic
-aider-chat
+aider-chat==0.82.2
backoff
-openai
+openai==1.73.1
google-generativeai
# Viz
matplotlib
diff --git a/templates/eeg_classification/dataset.sh b/templates/eeg_classification/dataset.sh
new file mode 100644
index 00000000..0d202771
--- /dev/null
+++ b/templates/eeg_classification/dataset.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+mkdir -p data
+
+for subj in A01 A02 A03 A04 A05 A06 A07 A08 A09
+do
+ for sess in T E
+ do
+ fname="${subj}${sess}.mat"
+ url="https://bnci-horizon-2020.eu/database/data-sets/001-2014/$fname"
+ wget -O "./data/$fname" "$url"
+ done
+done
\ No newline at end of file
diff --git a/templates/eeg_classification/experiment.py b/templates/eeg_classification/experiment.py
new file mode 100644
index 00000000..97ed7eba
--- /dev/null
+++ b/templates/eeg_classification/experiment.py
@@ -0,0 +1,214 @@
+import os
+import json
+import time
+import pickle
+import argparse
+import numpy as np
+from tqdm import tqdm
+from sklearn.metrics import confusion_matrix, accuracy_score
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader, random_split
+from torcheeg.datasets import BCICIV2aDataset
+from torcheeg import transforms
+
+TRAIN_SUBJECTS = ['A01', 'A02', 'A03', 'A04', 'A05', 'A06', 'A07']
+TEST_SUBJECTS = ['A08', 'A09']
+
+# https://torcheeg.readthedocs.io/en/latest/generated/torcheeg.datasets.BCICIV2aDataset.html
+class SubjectFilteredBCICIV2aDataset(BCICIV2aDataset):
+ def __init__(self, subjects: list, *args, **kwargs):
+ self.subjects = subjects
+ super().__init__(*args, **kwargs)
+
+ def set_records(self, root_path: str, **kwargs):
+ all_files = super().set_records(root_path)
+ return [f for f in all_files if os.path.basename(f)[:3] in self.subjects]
+
+# https://arxiv.org/pdf/1611.08024
+class EEGNet(nn.Module):
+ def __init__(self, num_classes=4, num_channels=22, num_samples=1750,
+ dropout_rate=0.5, kernel_length=64, F1=8, D=2, F2=16):
+ super(EEGNet, self).__init__()
+ self.first_conv = nn.Sequential(
+ nn.Conv2d(1, F1, (1, kernel_length), padding=(0, kernel_length // 2), bias=False),
+ nn.BatchNorm2d(F1)
+ )
+
+ self.second_conv = nn.Sequential(
+ nn.Conv2d(F1, F1 * D, (num_channels, 1), groups=F1, bias=False),
+ nn.BatchNorm2d(F1 * D),
+ nn.ELU(),
+ nn.AvgPool2d((1, 4)),
+ nn.Dropout(dropout_rate)
+ )
+
+ self.separable_conv = nn.Sequential(
+ nn.Conv2d(F1 * D, F1 * D, kernel_size=(1, 16), padding=(0, 8), groups=F1 * D, bias=False), # depth-wise
+ nn.Conv2d(F1 * D, F2, kernel_size=(1, 1), bias=False), # point-wise: 1x1 conv to combine depthwise outputs
+ nn.BatchNorm2d(F2),
+ nn.ELU(),
+ nn.AvgPool2d(kernel_size=(1, 8)),
+ nn.Dropout(dropout_rate)
+ )
+
+ self.classifier = nn.Sequential(
+ nn.Flatten(),
+ nn.Linear(F2 * ((num_samples // 32)), num_classes)
+ )
+
+ def forward(self, x):
+ x = self.first_conv(x)
+ x = self.second_conv(x)
+ x = self.separable_conv(x)
+ x = self.classifier(x)
+ return x
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Run experiment")
+ parser.add_argument("--out_dir", type=str, default="run_0", help="Output directory")
+ args = parser.parse_args()
+
+ # reproducibility
+ torch.manual_seed(42)
+ np.random.seed(42)
+
+ if torch.cuda.is_available():
+ print("Using GPU")
+ else:
+ print("Using CPU")
+
+ # logging directory
+ os.makedirs(args.out_dir, exist_ok=True)
+
+ # dataset
+ train_dataset = SubjectFilteredBCICIV2aDataset(
+ subjects=TRAIN_SUBJECTS,
+ root_path='./data',
+ online_transform=transforms.Compose([
+ transforms.To2d(),
+ transforms.ToTensor()
+ ]),
+ label_transform=transforms.Compose([
+ transforms.Select('label'),
+ transforms.Lambda(lambda x: x - 1)
+ ])
+ )
+
+ train_size = int(0.8 * len(train_dataset))
+ val_size = len(train_dataset) - train_size
+ train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])
+ train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
+ val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)
+
+ # main Loop
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ model = EEGNet(num_classes=4, kernel_length=125).to(device)
+
+ criterion = nn.CrossEntropyLoss()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ train_losses = []
+ val_accuracies = []
+
+ start_time = time.time()
+ for epoch in tqdm(range(100)):
+ model.train()
+ total_loss = 0
+ for X, y in train_loader:
+ X, y = X.to(device), y.to(device)
+ optimizer.zero_grad()
+ logits = model(X)
+ loss = criterion(logits, y)
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item()
+
+ train_losses.append(total_loss)
+
+ model.eval()
+ correct, total = 0, 0
+ all_preds, all_labels = [], []
+ with torch.no_grad():
+ for X, y in val_loader:
+ X, y = X.to(device), y.to(device)
+ logits = model(X)
+ preds = logits.argmax(dim=1)
+ all_preds.append(preds.cpu().numpy())
+ all_labels.append(y.cpu().numpy())
+ correct += (preds == y).sum().item()
+ total += y.size(0)
+
+ acc = correct / total
+ val_accuracies.append(acc)
+
+ end_time = time.time()
+ training_time = end_time - start_time
+
+ # test the model
+ test_dataset = SubjectFilteredBCICIV2aDataset(
+ subjects=TEST_SUBJECTS,
+ root_path='./data/BCICIV_2a_mat_evaluate',
+ online_transform=transforms.Compose([
+ transforms.To2d(),
+ transforms.ToTensor()
+ ]),
+ label_transform=transforms.Compose([
+ transforms.Select('label'),
+ transforms.Lambda(lambda x: x - 1)
+ ])
+ )
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
+
+ results = []
+ model.eval()
+ with torch.no_grad():
+ for X, y in tqdm(test_loader, desc="Evaluating"):
+ X = X.to(device)
+ logits = model(X)
+ preds = logits.argmax(dim=1).cpu().tolist()
+ labels = y.tolist()
+ for p, l in zip(preds, labels):
+ results.append({"pred": int(p), "label": int(l)})
+
+ labels = [r["label"] for r in results]
+ preds = [r["pred"] for r in results]
+ test_acc = accuracy_score(labels, preds)
+ conf_mat = confusion_matrix(labels, preds).tolist()
+
+ # save the outputs (structure for a single run)
+ final_info = {
+ "eegnet": {
+ "means": {
+ "training_time": training_time,
+ "final_val_accuracy": val_accuracies[-1],
+ "test_accuracy": test_acc,
+ },
+ "stderrs": {
+ "training_time_stderr": 0.0,
+ "final_val_accuracy_stderr": 0.0,
+ "test_accuracy_stderr": 0.0,
+ },
+ "final_info_dict": {
+ "training_time": [training_time],
+ "final_val_accuracy": [val_accuracies[-1]],
+ "test_accuracy": [test_acc],
+ }
+ }
+ }
+ with open(os.path.join(args.out_dir, "final_info.json"), "w") as f:
+ json.dump(final_info, f, indent=2)
+
+ all_results = {
+ "eegnet_train_losses": train_losses,
+ "eegnet_val_accuracies": val_accuracies,
+ "eegnet_test_results": results,
+ "eegnet_confusion_matrix": conf_mat
+ }
+ with open(os.path.join(args.out_dir, "all_results.pkl"), "wb") as f:
+ pickle.dump(all_results, f)
+
+ torch.save(model.state_dict(), os.path.join(args.out_dir, "model.pth"))
\ No newline at end of file
diff --git a/templates/eeg_classification/ideas.json b/templates/eeg_classification/ideas.json
new file mode 100644
index 00000000..e69de29b
diff --git a/templates/eeg_classification/plot.py b/templates/eeg_classification/plot.py
new file mode 100644
index 00000000..36dc2a5d
--- /dev/null
+++ b/templates/eeg_classification/plot.py
@@ -0,0 +1,107 @@
+import os
+import json
+import pickle
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
+
+
+def get_run_dirs(base_dir='.'):
+ return sorted([
+ d for d in os.listdir(base_dir)
+ if d.startswith('run') and os.path.isdir(os.path.join(base_dir, d))
+ ])
+
+
+def plot_training_curves(run_dirs):
+ fig, ax1 = plt.subplots(figsize=(10, 6))
+ ax2 = ax1.twinx()
+ colors = plt.cm.tab10(np.linspace(0, 1, len(run_dirs)))
+
+ for idx, run in enumerate(run_dirs):
+ pkl_path = os.path.join(run, 'all_results.pkl')
+ if not os.path.exists(pkl_path):
+ print(f"[{run}] missing all_results.pkl")
+ continue
+
+ with open(pkl_path, 'rb') as f:
+ results = pickle.load(f)
+
+ train_losses = results.get('eegnet_train_losses', [])
+ val_accuracies = np.array(results.get('eegnet_val_accuracies', [])) * 100
+
+ ax1.plot(train_losses, label=f'{run} - train loss', color=colors[idx], linestyle='-')
+ ax2.plot(val_accuracies, label=f'{run} - val acc', color=colors[idx], linestyle='--')
+
+ ax1.set_xlabel("Epoch")
+ ax1.set_ylabel("Train Loss", color='blue')
+ ax2.set_ylabel("Validation Accuracy (%)", color='green')
+
+ ax1.tick_params(axis='y', labelcolor='blue')
+ ax2.tick_params(axis='y', labelcolor='green')
+
+ lines_1, labels_1 = ax1.get_legend_handles_labels()
+ lines_2, labels_2 = ax2.get_legend_handles_labels()
+ plt.legend(lines_1 + lines_2, labels_1 + labels_2, loc="center right")
+
+ plt.title("Train Loss & Validation Accuracy")
+ plt.grid(True)
+ plt.tight_layout()
+ plt.savefig("plot_training_curves.png")
+ plt.close()
+
+
+def analyze_test_results(run_dirs):
+ test_accs = []
+
+ for run in run_dirs:
+ final_info_path = os.path.join(run, 'final_info.json')
+ all_results_path = os.path.join(run, 'all_results.pkl')
+
+ if not (os.path.exists(final_info_path) and os.path.exists(all_results_path)):
+ print(f"[{run}] missing result files")
+ continue
+
+ with open(final_info_path, 'r') as f:
+ final_info = json.load(f)
+ test_acc = final_info['eegnet']['means']['test_accuracy']
+
+ with open(all_results_path, 'rb') as f:
+ results = pickle.load(f)
+ test_results = results.get('eegnet_test_results', [])
+
+ y_true = [x['label'] for x in test_results]
+ y_pred = [x['pred'] for x in test_results]
+
+ cm = confusion_matrix(y_true, y_pred)
+ disp = ConfusionMatrixDisplay(confusion_matrix=cm)
+ disp.plot(cmap='Blues', values_format='d')
+ plt.title(f"Confusion Matrix – {run}\nTest Accuracy: {test_acc * 100:.2f}%")
+ plt.tight_layout()
+ plt.savefig(f"plot_confusion_matrix_{run}.png")
+ plt.close()
+
+ print(f"[{run}] Test accuracy: {test_acc * 100:.2f}%")
+ test_accs.append((run, test_acc))
+
+ if test_accs:
+ test_accs.sort(key=lambda x: x[0])
+ run_names, accs = zip(*test_accs)
+
+ plt.figure(figsize=(8, 5))
+ plt.bar(run_names, [a * 100 for a in accs], color='teal')
+ plt.title("Test Accuracy by Run")
+ plt.ylabel("Accuracy (%)")
+ plt.ylim(0, 100)
+ plt.grid(axis='y', linestyle='--', alpha=0.5)
+ plt.tight_layout()
+ plt.savefig("plot_test_accuracy_by_run.png")
+ plt.close()
+
+
+if __name__ == "__main__":
+ run_dirs = get_run_dirs()
+ print(f"Found {len(run_dirs)} runs: {run_dirs}")
+
+ plot_training_curves(run_dirs)
+ analyze_test_results(run_dirs)
diff --git a/templates/eeg_classification/prompt.json b/templates/eeg_classification/prompt.json
new file mode 100644
index 00000000..354c68dc
--- /dev/null
+++ b/templates/eeg_classification/prompt.json
@@ -0,0 +1,4 @@
+{
+ "system": "You are an ambitious AI researcher who is looking to publish a paper that will contribute significantly to the field.",
+ "task_description": "You are given the following file to work with, which trains a neural network model which predicts the class of motor by capturing brain activity, i.e., electroencephalography (EEG) signals. Your objective is to design and optimize a novel representation learning methods that can accurately classify mental states and generalize across different subjects."
+}
\ No newline at end of file
diff --git a/templates/eeg_classification/readme.md b/templates/eeg_classification/readme.md
new file mode 100644
index 00000000..8235b861
--- /dev/null
+++ b/templates/eeg_classification/readme.md
@@ -0,0 +1,39 @@
+# EEG Classification with The AI Scientist
+
+This repository demonstrates a novel technique for EEG classification, discovered by [The AI Scientist](https://github.com/SakanaAI/AI-Scientist).
+
+## Result
+
+The paper generated by The AI Scientist can be found in this [link](https://drive.google.com/file/d/1Uov-TPKi9T-u-VcHyJgoRCDNr1_IgXvO/view?usp=sharing)
+
+## Prepare the dataset
+
+This template exploits a dataset for motor imagery, named BCICIV2aDataset, used for BCI Competition 2008 Graz.
+More details can be found in this [URL](https://torcheeg.readthedocs.io/en/latest/generated/torcheeg.datasets.BCICIV2aDataset.html).
+
+```bash
+pip install -r requirements.txt
+bash dataset.sh
+```
+
+## Run the template
+
+For the baseline, this template implements EEGNet, introduced in this [paper](https://arxiv.org/abs/1611.08024).
+
+```bash
+# train the model
+python experiment.py --out_dir run_0
+
+# generate visualization plots
+python plot.py
+```
+
+## Launch The AI Scientist
+
+```bash
+# launch The AI Scientist
+python launch_scientist.py \
+ --model "gpt-4o-2024-08-06" \
+ --experiment eeg_classification \
+ --num-ideas 1
+```
\ No newline at end of file
diff --git a/templates/eeg_classification/requirements.txt b/templates/eeg_classification/requirements.txt
new file mode 100644
index 00000000..e5cb2ed8
--- /dev/null
+++ b/templates/eeg_classification/requirements.txt
@@ -0,0 +1 @@
+torcheeg
\ No newline at end of file
diff --git a/templates/eeg_classification/seed_ideas.json b/templates/eeg_classification/seed_ideas.json
new file mode 100644
index 00000000..382ffa15
--- /dev/null
+++ b/templates/eeg_classification/seed_ideas.json
@@ -0,0 +1,10 @@
+[
+ {
+ "Name": "eegnet",
+ "Title": "EEGNet: A Compact Convolutional Neural Network for EEG-based Brain-Computer Interfaces.",
+ "Experiment": "This experiment evaluates the EEGNet architecture on the BCI Competition IV 2a dataset using a train/test split across subjects. The final evaulation is reported using test accuracy, along with confusion matrices and training time statistics.",
+ "Interestingness": 5,
+ "Feasibility": 9,
+ "Novelty": 4
+ }
+]