36 changes: 18 additions & 18 deletions ai_scientist/llm.py
@@ -5,8 +5,8 @@
import anthropic
import backoff
import openai
-import google.generativeai as genai
-from google.generativeai.types import GenerationConfig
+import google.genai as genai
+from google.genai.types import GenerateContentConfig

MAX_NUM_TOKENS = 4096

@@ -59,6 +59,8 @@
"gemini-2.0-flash-thinking-exp-01-21",
"gemini-2.5-pro-preview-03-25",
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-flash",
"gemini-2.5-pro"
]


@@ -148,6 +150,7 @@ def get_response_from_llm(
    print_debug=False,
    msg_history=None,
    temperature=0.75,
+    tools=[],
):
    if msg_history is None:
        msg_history = []
@@ -258,19 +261,17 @@
        content = response.choices[0].message.content
        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
    elif "gemini" in model:
-        new_msg_history = msg_history + [{"role": "user", "content": msg}]
-        response = client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_message},
-                *new_msg_history,
-            ],
-            temperature=temperature,
-            max_tokens=MAX_NUM_TOKENS,
-            n=1,
+        chat = client.chats.create(model=model, history=msg_history)
+        response = chat.send_message(
+            msg,
+            config=GenerateContentConfig(
+                system_instruction=system_message,
+                max_output_tokens=MAX_NUM_TOKENS,
+                temperature=temperature
+            )
        )
-        content = response.choices[0].message.content
-        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
+        content = response.text
+        new_msg_history = chat.get_history()
    else:
        raise ValueError(f"Model {model} not supported.")

@@ -342,10 +343,9 @@ def create_client(model):
base_url="https://openrouter.ai/api/v1"
), "meta-llama/llama-3.1-405b-instruct"
elif "gemini" in model:
print(f"Using OpenAI API with {model}.")
return openai.OpenAI(
api_key=os.environ["GEMINI_API_KEY"],
base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
print(f"Using Google GenAI API with {model}.")
return genai.client.Client(
api_key=os.environ["GEMINI_API_KEY"]
), model
else:
raise ValueError(f"Model {model} not supported.")
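
For reference, a minimal standalone sketch of the google-genai flow the updated branches rely on. The model name, prompt, and system instruction are illustrative, and `GEMINI_API_KEY` is assumed to be set:

```
import os

import google.genai as genai
from google.genai.types import GenerateContentConfig

# genai.Client is the documented entry point; genai.client.Client above
# resolves to the same class.
client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

chat = client.chats.create(model="gemini-2.5-flash", history=[])
response = chat.send_message(
    "Say hello in one sentence.",
    config=GenerateContentConfig(
        system_instruction="You are a helpful assistant.",
        max_output_tokens=4096,
        temperature=0.75,
    ),
)
print(response.text)

# Caveat: get_history() returns google.genai Content objects, not the
# OpenAI-style {"role": ..., "content": ...} dicts used by the other
# branches of get_response_from_llm.
history = chat.get_history()
```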
4 changes: 3 additions & 1 deletion ai_scientist/perform_experiments.py
@@ -61,7 +61,9 @@ def run_experiment(folder_name, run_num, timeout=7200):
        else:
            with open(osp.join(cwd, f"run_{run_num}", "final_info.json"), "r") as f:
                results = json.load(f)
-            results = {k: v["means"] for k, v in results.items()}
+            if isinstance(results, dict) and \
+                    all(isinstance(v, dict) and v.get("means") for v in results.values()):
+                results = {k: v["means"] for k, v in results.items()}

next_prompt = f"""Run {run_num} completed. Here are the results:
{results}
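As a quick illustration of the added guard, a sketch with two assumed `final_info.json` shapes (both shapes are illustrative assumptions):

```
nested = {"loss": {"means": 0.42, "stderrs": 0.01}}          # classic templates
flat = {"themes": ["social harmony", "personal wellbeing"]}  # e.g. survey output

for results in (nested, flat):
    if isinstance(results, dict) and \
            all(isinstance(v, dict) and v.get("means") for v in results.values()):
        results = {k: v["means"] for k, v in results.items()}
    print(results)  # nested flattens to {'loss': 0.42}; flat passes through
```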
22 changes: 20 additions & 2 deletions launch_scientist.py
@@ -26,6 +26,11 @@ def print_time():
    print(datetime.now().strftime("%Y-%m-%d %H:%M:%S"))


+def comma_separated_list(string):
+    if not string:
+        return []
+    return [item.strip() for item in string.split(',')]
+
def parse_arguments():
    parser = argparse.ArgumentParser(description="Run AI scientist experiments")
    parser.add_argument(
@@ -89,6 +94,12 @@ def parse_arguments():
choices=["semanticscholar", "openalex"],
help="Scholar engine to use.",
)
parser.add_argument(
"--per-experiment-files",
type=comma_separated_list,
default=[],
help="A list of files to be inlucded in addition to experiment.py",
)
return parser.parse_args()


@@ -161,6 +172,7 @@ def do_idea(
    writeup,
    improvement,
    log_file=False,
+    per_experiment_files=[],
):
## CREATE PROJECT FOLDER
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -172,11 +184,13 @@
    with open(osp.join(base_dir, "run_0", "final_info.json"), "r") as f:
        baseline_results = json.load(f)
    # Check if baseline_results is a dictionary before extracting means
-    if isinstance(baseline_results, dict):
+    if isinstance(baseline_results, dict) and \
+            all(isinstance(v, dict) and v.get("means") for v in baseline_results.values()):
        baseline_results = {k: v["means"] for k, v in baseline_results.items()}
    exp_file = osp.join(folder_name, "experiment.py")
    vis_file = osp.join(folder_name, "plot.py")
    notes = osp.join(folder_name, "notes.txt")
+    per_experiment_files = [osp.join(folder_name, f) for f in per_experiment_files]
with open(notes, "w") as f:
f.write(f"# Title: {idea['Title']}\n")
f.write(f"# Experiment description: {idea['Experiment']}\n")
@@ -194,7 +208,7 @@
    print_time()
    print(f"*Starting idea: {idea_name}*")
    ## PERFORM EXPERIMENTS
-    fnames = [exp_file, vis_file, notes]
+    fnames = [exp_file, vis_file, notes] + per_experiment_files
    io = InputOutput(
        yes=True, chat_history_file=f"{folder_name}/{idea_name}_aider.txt"
    )
@@ -259,6 +273,9 @@ def do_idea(
    else:
        raise ValueError(f"Writeup format {writeup} not supported.")

+    print("stop before reviewing")
+    sys.exit(1)
+
    print_time()
    print(f"*Starting Review*")
    ## REVIEW PAPER
@@ -411,6 +428,7 @@ def do_idea(
                client_model,
                args.writeup,
                args.improvement,
+                per_experiment_files=args.per_experiment_files,
            )
            print(f"Completed idea: {idea['Name']}, Success: {success}")
        except Exception as e:
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ aider-chat
backoff
openai
google-generativeai
+google-genai
# Viz
matplotlib
pypdf
25 changes: 25 additions & 0 deletions templates/psycology_survey/README.md
@@ -0,0 +1,25 @@
# Psychology Survey Template

The psychology survey template explores using the AI-Scientist to simulate psychology studies with LLM personas and then survey those personas for analysis.

[LLMs have been used](https://arxiv.org/pdf/2304.03442) to simulate human behaviours. Since LLMs encode a wide range of human behaviours, there is potential value in recreating historical psychology studies with LLM personas.

This template does so by templatizing survey-based approaches from psychology studies.

To use this template, define three components (a sketch follows this list):

1) `survey.json`. Defines who is surveyed, which questions are asked, and what analysis is run.
2) Personas. YAML files that describe the virtual personas to be surveyed.
3) `analyze.py`. Maps the analysis types named in `survey.json` to Python functions.
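
A minimal sketch of how the pieces might fit together; the field names and persona attributes below are illustrative assumptions, not the template's exact schema:

```
# Hypothetical shapes; the real schemas live in data/survey.json,
# data/personas/*.yaml, and analyze.py.
survey = {  # mirrors data/survey.json
    "respondents": ["jp", "us"],
    "questions": [
        {"id": "1", "text": "Describe a recent happy moment.",
         "response_format": "free_form"},
    ],
    "analysis": [
        {"question_id": "1", "artifact_id": "happiness_similarity",
         "processor": "free_form_to_similarity_matrix",
         "desc": "Build a similarity matrix over happy moments"},
    ],
}

persona = {  # mirrors one entry in data/personas/jp.yaml
    "name": "Yuki",
    "location": "Kyoto, Japan",
    "background": "28-year-old graduate student",
}
```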

As an example, `ideas.json` is populated with an idea that tries to reproduce the landmark psychology paper [Happiness and Unhappiness in East and West](https://www.researchgate.net/profile/Yukiko-Uchida/publication/26716010_Happiness_and_Unhappiness_in_East_and_West_Themes_and_Variations/links/0c960525e02a7d5940000000/Happiness-and-Unhappiness-in-East-and-West-Themes-and-Variations.pdf), which examined differences in the themes of happy and unhappy emotions between Japan and the United States. The template aims to reproduce the paper using LLM personas set in Japan and the US.


## Happiness and Unhappiness in East and West example

Run the following command to reproduce this experiment; an example paper is included. Note that the experiment does not require GPUs.

```
python launch_scientist.py --model "gemini-2.5-pro" --experiment psycology_survey --skip-idea-generation --skip-novelty-check --per-experiment-files analyze.py,personas.py,survey.py,data/personas/jp.yaml,data/personas/us.yaml,data/survey.json
```
140 changes: 140 additions & 0 deletions templates/psycology_survey/analyze.py
@@ -0,0 +1,140 @@
"""Helper functions that analysis survey results.

An analysis is a Python function that takes a SurveyResult as input
and either mutates it or stores an artifact in final_infos.

Example Processor: PointScaleProcessor

How it is used: Specify the following in survey.json:
    {
        "question_id": "2",
        "point_scale": [1, 2, 3, 4, 5],
        "desc": "Change jp happiness desirability response to a 5-point scale"
    }

What it does: It mutates the input SurveyResult by turning answers into point-scale values.
"""
from typing import Literal, Union
import json
from collections import defaultdict

from survey import SurveyResult
from llm import LLMInput, get_llm_response

import pydantic
from sklearn.manifold import MDS
import numpy as np

ANALYSIS_SYSTEM_INSTRUCTION = """
You are a PhD student responsible for analyzing survey results.
Try to represent as many opinions as possible from the surveys.
"""

class PointScaleProcessor(pydantic.BaseModel):
question_id: str
point_scale: list[int]
desc: str

def process(self, survey_result: SurveyResult, artifacts: dict):
"""Mutates literal answers to point scales"""
for respondent_id, resp in survey_result.responses.items():
for question, answer in resp.items():
if question.id != self.question_id:
continue
                if len(answer) != len(self.point_scale):
raise Exception("Point scale doesn't match the shape of survey result")

original_choices = question.response_format.sub_type.choices
for k, v in answer.items():
point_scale_val = self.point_scale[original_choices.index(v)]
survey_result.responses[respondent_id][question][k] = point_scale_val

class MergeProcessor(pydantic.BaseModel):
"""Pair responses for analysis purpose.

Example:
question: "What is your favorite dessert?" A: "cake"
merge_with: "On a scale of 1 ~ 5, how much do you like this dessert?" A: 3

merged answer: ("cake", 3)
"""
question_id: str
merge_with: str
desc: str

def process(self, survey_result: SurveyResult, artifacts: dict):
for respondent_id, resp in survey_result.responses.items():
            # Question was not targeted at this respondent
            if self.question_id not in {q.id for q in resp}:
continue
other_resp = None
for question, answer in resp.items():
if question.id == self.merge_with:
other_resp = answer
            if other_resp is None:
raise Exception("Merge response failed.")
for question, answer in resp.items():
if question.id != self.question_id:
continue
for k, v in answer.items():
merged = [v, other_resp[k]]
survey_result.responses[respondent_id][question][k] = merged


class SimilarityMatrix(pydantic.BaseModel):
matrix: list[list[float]]
responses: list[str]


class SimilarityMatrixProcessor(pydantic.BaseModel):
"""Given a list of statements, use LLM to create a similarity matrix.

The similarity matrix is stored in final_infos, and will be used
for downstream analysis (ex: MDS analysis).
"""
question_id: str
artifact_id: str
processor: Literal["free_form_to_similarity_matrix"]
desc: str

def process(self, survey_result: SurveyResult, final_infos: dict):
all_answers = []
for respondent_id, resp in survey_result.responses.items():
for question, answer in resp.items():
if question.id != self.question_id:
continue
for k, v in answer.items():
if isinstance(v, list) or isinstance(v, tuple):
# For paired answers, assume first item is free form.
all_answers.append(v[0])
else:
all_answers.append(v)

answers_str = (",").join(all_answers)
prompt = f"""
Given a list of n statements, return a similarity matrix of n x n.
The similarity matrix should be symmetrical, and a greater value
at matrix[i][j] means that statements i and j are more similar.
All values should be between 0.0 and 1.0 inclusive. Before returning
the answer, ensure that the final matrix is symmetric.

example input: ["I like apple", "I like pear", "I hate food"]
example output: [[1.0,0.8,0.1], [0.8,1.0,0.1], [0.1,0.1,1.0]]

input: {answers_str}
"""
        llm_input = LLMInput(
            prompt=prompt,
            max_output_tokens=50000,
            temperature=0,
            response_type='application/json',
            system_instruction=ANALYSIS_SYSTEM_INSTRUCTION,
        )
        resp = get_llm_response(llm_input)
        data = json.loads(resp)
        final_infos[self.artifact_id] = \
            SimilarityMatrix(matrix=data, responses=all_answers).model_dump()

class Processors(pydantic.BaseModel):
processor: Union[PointScaleProcessor, MergeProcessor, SimilarityMatrixProcessor]
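
# Illustrative sketch (an assumption, not part of the diff above): the MDS
# import suggests the stored SimilarityMatrix artifact feeds a downstream
# MDS projection. The function name and the 1 - similarity conversion to a
# dissimilarity matrix are assumptions about the intended analysis.
def embed_similarity_matrix(artifact: dict) -> np.ndarray:
    sim = np.array(artifact["matrix"])
    mds = MDS(n_components=2, dissimilarity="precomputed", random_state=0)
    return mds.fit_transform(1.0 - sim)  # 2-D points, one per statement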