diff --git a/ai_scientist/llm.py b/ai_scientist/llm.py
index 27c9eee8..40a3dd42 100644
--- a/ai_scientist/llm.py
+++ b/ai_scientist/llm.py
@@ -5,8 +5,8 @@
 import anthropic
 import backoff
 import openai
-import google.generativeai as genai
-from google.generativeai.types import GenerationConfig
+import google.genai as genai
+from google.genai.types import GenerateContentConfig
 
 MAX_NUM_TOKENS = 4096
 
@@ -258,19 +258,19 @@ def get_response_from_llm(
         content = response.choices[0].message.content
         new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
     elif "gemini" in model:
-        new_msg_history = msg_history + [{"role": "user", "content": msg}]
-        response = client.chat.completions.create(
-            model=model,
-            messages=[
-                {"role": "system", "content": system_message},
-                *new_msg_history,
-            ],
-            temperature=temperature,
-            max_tokens=MAX_NUM_TOKENS,
-            n=1,
+        # The `google-genai` library manages history differently.
+        chat = client.chats.create(model=model, history=msg_history)
+        response = chat.send_message(
+            msg,
+            config=GenerateContentConfig(
+                system_instruction=system_message,
+                max_output_tokens=MAX_NUM_TOKENS,
+                temperature=temperature,
+            )
         )
-        content = response.choices[0].message.content
-        new_msg_history = new_msg_history + [{"role": "assistant", "content": content}]
+        content = response.text
+        # The new history is the chat object's history
+        new_msg_history = chat.get_history()
     else:
         raise ValueError(f"Model {model} not supported.")
 
@@ -342,10 +342,9 @@ def create_client(model):
             base_url="https://openrouter.ai/api/v1"
         ), "meta-llama/llama-3.1-405b-instruct"
     elif "gemini" in model:
-        print(f"Using OpenAI API with {model}.")
-        return openai.OpenAI(
+        print(f"Using Google GenAI API with {model}.")
+        return genai.client.Client(
             api_key=os.environ["GEMINI_API_KEY"],
-            base_url="https://generativelanguage.googleapis.com/v1beta/openai/"
         ), model
     else:
         raise ValueError(f"Model {model} not supported.")
diff --git a/launch_scientist.py b/launch_scientist.py
index 2fe7a49c..e601c536 100644
--- a/launch_scientist.py
+++ b/launch_scientist.py
@@ -204,6 +204,8 @@ def do_idea(
         main_model = Model("deepseek/deepseek-reasoner")
     elif model == "llama3.1-405b":
         main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
+    elif "gemini" in model:
+        main_model = Model(f"gemini/{model}")
     else:
         main_model = Model(model)
     coder = Coder.create(
@@ -240,6 +242,8 @@ def do_idea(
         main_model = Model("deepseek/deepseek-reasoner")
     elif model == "llama3.1-405b":
         main_model = Model("openrouter/meta-llama/llama-3.1-405b-instruct")
+    elif "gemini" in model:
+        main_model = Model(f"gemini/{model}")
     else:
         main_model = Model(model)
     coder = Coder.create(
@@ -360,7 +364,7 @@ def do_idea(
     with open(osp.join(base_dir, "ideas.json"), "w") as f:
         json.dump(ideas, f, indent=4)
 
-    novel_ideas = [idea for idea in ideas if idea["novel"]]
+    novel_ideas = ideas if args.skip_novelty_check else [idea for idea in ideas if idea["novel"]]
     # novel_ideas = list(reversed(novel_ideas))
 
     if args.parallel > 0:
diff --git a/requirements.txt b/requirements.txt
index 8971848d..f42f61ea 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,7 +3,9 @@ anthropic
 aider-chat
 backoff
 openai
+# aider-chat still uses the old SDK
 google-generativeai
+google-genai
 # Viz
 matplotlib
 pypdf
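
Reviewer note: below is a minimal standalone sketch of the `google-genai` chat flow this patch adopts, for sanity-checking the new calls outside the repo. The model name, prompt strings, and temperature are illustrative placeholders, not values from this patch.

    import os

    import google.genai as genai
    from google.genai.types import GenerateContentConfig

    client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

    # History lives on the Chat object instead of a list of
    # OpenAI-style {"role": ..., "content": ...} dicts.
    chat = client.chats.create(model="gemini-2.0-flash", history=[])
    response = chat.send_message(
        "Say hello in one sentence.",
        config=GenerateContentConfig(
            system_instruction="You are a helpful assistant.",
            max_output_tokens=4096,
            temperature=0.7,
        ),
    )
    print(response.text)
    # get_history() returns google.genai Content objects, which is the
    # shape get_response_from_llm now passes around as msg_history.
    print(chat.get_history())

One thing worth checking in review: `chat.get_history()` returns `Content` objects rather than plain dicts, so any caller that previously indexed `msg_history[i]["content"]` needs to be verified against the new shape.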