diff --git a/README.md b/README.md index 7e8b74d8..f9b333b6 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,16 @@ We support a wide variety of models including open-weight and API-only models. I By default, this uses the `OPENAI_API_KEY` environment variable. +##### OpenAI Models via Azure OpenAI Service + +This setup uses the `openai` Python package's Azure client, so no additional packages are required. + +Three environment variables are required: `AZURE_API_KEY`, `AZURE_API_BASE`, and `AZURE_API_VERSION`. These follow the naming convention specified in [Aider's documentation](https://aider.chat/docs/llms/azure.html). + +When accessing models via Azure OpenAI Service, API calls use the deployment name instead of the underlying model name. Name your deployment after a standard OpenAI model (e.g., `gpt-4o-2024-08-06`) and select it with `--model azure/<deployment-name>`. For more information, see [Microsoft's guide to switching between OpenAI and Azure OpenAI endpoints](https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/switching-endpoints). + +If you need to bypass SSL certificate verification, pass `--no-verify-ssl` on the command line when running the script. + #### Anthropic API (Claude Sonnet 3.5) By default, this uses the `ANTHROPIC_API_KEY` environment variable. 
diff --git a/ai_scientist/generate_ideas.py b/ai_scientist/generate_ideas.py index a8feedfe..99815342 100644 --- a/ai_scientist/generate_ideas.py +++ b/ai_scientist/generate_ideas.py @@ -278,8 +278,16 @@ def on_backoff(details): ) +# @backoff.on_exception( +# backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff +# ) @backoff.on_exception( - backoff.expo, requests.exceptions.HTTPError, on_backoff=on_backoff + backoff.constant, + requests.exceptions.HTTPError, + interval=5, + jitter=backoff.full_jitter, + on_backoff=on_backoff, + # max_tries=10, ) def search_for_papers(query, result_limit=10) -> Union[None, List[Dict]]: if not query: @@ -447,7 +455,6 @@ def check_idea_novelty( if __name__ == "__main__": - MAX_NUM_GENERATIONS = 32 NUM_REFLECTIONS = 5 import argparse @@ -468,9 +475,34 @@ def check_idea_novelty( "gpt-4o-2024-05-13", "deepseek-coder-v2-0724", "llama3.1-405b", + # Anthropic Claude models via Amazon Bedrock + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/anthropic.claude-3-opus-20240229-v1:0", + # Anthropic Claude models Vertex AI + "vertex_ai/claude-3-opus@20240229", + "vertex_ai/claude-3-5-sonnet@20240620", + "vertex_ai/claude-3-sonnet@20240229", + "vertex_ai/claude-3-haiku@20240307", + # OpenAI models via Azure + "azure/gpt-4o-2024-08-06", + "azure/gpt-4o-2024-05-13", ], help="Model to use for AI Scientist.", ) + parser.add_argument( + "--verify-ssl", # implemented only for Azure OpenAI API + action=argparse.BooleanOptionalAction, + default=True, + help="Verify SSL certificate when connecting to model API (default: True)", + ) + parser.add_argument( + "--num-ideas", + type=int, + default=50, + help="Number of ideas to generate", + ) parser.add_argument( "--skip-idea-generation", action="store_true", @@ -482,7 +514,6 @@ def check_idea_novelty( help="Check novelty of ideas.", ) args = parser.parse_args() - # Create 
client if args.model == "claude-3-5-sonnet-20240620": import anthropic @@ -512,6 +543,20 @@ def check_idea_novelty( print(f"Using OpenAI API with model {args.model}.") client_model = "gpt-4o-2024-05-13" client = openai.OpenAI() + elif args.model.startswith("azure") and "gpt" in args.model: + import openai + if not args.verify_ssl: import httpx + + # Expects: azure/ + client_model = args.model.split("/")[-1] + + print(f"Using Azure with model {client_model}.") + client = openai.AzureOpenAI( + api_key=os.getenv("AZURE_API_KEY"), + api_version=os.getenv("AZURE_API_VERSION"), + azure_endpoint=os.getenv("AZURE_API_BASE"), + http_client = httpx.Client(verify=False) if not args.verify_ssl else None, + ) elif args.model == "deepseek-coder-v2-0724": import openai @@ -539,7 +584,7 @@ def check_idea_novelty( client=client, model=client_model, skip_generation=args.skip_idea_generation, - max_num_generations=MAX_NUM_GENERATIONS, + max_num_generations=args.num_ideas, num_reflections=NUM_REFLECTIONS, ) if args.check_novelty: diff --git a/ai_scientist/perform_writeup.py b/ai_scientist/perform_writeup.py index c32565e9..c7d428c3 100644 --- a/ai_scientist/perform_writeup.py +++ b/ai_scientist/perform_writeup.py @@ -533,16 +533,33 @@ def perform_writeup( "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock/anthropic.claude-3-haiku-20240307-v1:0", - "bedrock/anthropic.claude-3-opus-20240229-v1:0" + "bedrock/anthropic.claude-3-opus-20240229-v1:0", # Anthropic Claude models Vertex AI "vertex_ai/claude-3-opus@20240229", "vertex_ai/claude-3-5-sonnet@20240620", "vertex_ai/claude-3-sonnet@20240229", - "vertex_ai/claude-3-haiku@20240307" + "vertex_ai/claude-3-haiku@20240307", + # OpenAI models via Azure + "azure/gpt-4o-2024-08-06", + "azure/gpt-4o-2024-05-13", ], help="Model to use for AI Scientist.", ) + parser.add_argument( + "--verify-ssl", # implemented only for Azure OpenAI API + action=argparse.BooleanOptionalAction, + 
default=True, + help="Verify SSL certificate when connecting to model API (default: True)", + ) + args = parser.parse_args() + # [--no-verify-ssl] Disable SSL verification for aider's litellm calls + if not args.verify_ssl: + from aider.llm import litellm + import httpx + litellm._load_litellm() + litellm._lazy_module.client_session = httpx.Client(verify=False) + # Create client if args.model == "claude-3-5-sonnet-20240620": import anthropic @@ -571,6 +588,19 @@ def perform_writeup( print(f"Using OpenAI API with model {args.model}.") client_model = "gpt-4o-2024-05-13" client = openai.OpenAI() + elif args.model.startswith("azure") and "gpt" in args.model: + import openai + + # Expects: azure/ + client_model = args.model.split("/")[-1] + + print(f"Using Azure with model {client_model}.") + client = openai.AzureOpenAI( + api_key=os.getenv("AZURE_API_KEY"), + api_version=os.getenv("AZURE_API_VERSION"), + azure_endpoint=os.getenv("AZURE_API_BASE"), + http_client=httpx.Client(verify=False) if not args.verify_ssl else None, + ) elif args.model == "deepseek-coder-v2-0724": import openai diff --git a/launch_scientist.py b/launch_scientist.py index 489c7fc8..33f2981f 100644 --- a/launch_scientist.py +++ b/launch_scientist.py @@ -56,15 +56,25 @@ def parse_arguments(): "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock/anthropic.claude-3-haiku-20240307-v1:0", - "bedrock/anthropic.claude-3-opus-20240229-v1:0" + "bedrock/anthropic.claude-3-opus-20240229-v1:0", # Anthropic Claude models Vertex AI "vertex_ai/claude-3-opus@20240229", "vertex_ai/claude-3-5-sonnet@20240620", "vertex_ai/claude-3-sonnet@20240229", "vertex_ai/claude-3-haiku@20240307", + # OpenAI models via Azure + "azure/gpt-4o-2024-08-06", + "azure/gpt-4o-2024-05-13", ], help="Model to use for AI Scientist.", ) + + parser.add_argument( + "--verify-ssl", # implemented only for Azure OpenAI API + action=argparse.BooleanOptionalAction, + default=True, + 
help="Verify SSL certificate when connecting to model API (default: True)", + ) parser.add_argument( "--writeup", type=str, @@ -310,6 +320,13 @@ def do_idea( print(f"Using GPUs: {available_gpus}") + # [--no-verify-ssl] Disable SSL verification for aider's litellm calls + if not args.verify_ssl: + from aider.llm import litellm + import httpx + litellm._load_litellm() + litellm._lazy_module.client_session = httpx.Client(verify=False) + # Create client if args.model == "claude-3-5-sonnet-20240620": import anthropic @@ -343,6 +360,19 @@ def do_idea( print(f"Using OpenAI API with model {args.model}.") client_model = "gpt-4o-2024-05-13" client = openai.OpenAI() + elif args.model.startswith("azure") and "gpt" in args.model: + import openai + + # Expects: azure/ + client_model = args.model.split("/")[-1] + + print(f"Using Azure with model {client_model}.") + client = openai.AzureOpenAI( + api_key=os.getenv("AZURE_API_KEY"), + api_version=os.getenv("AZURE_API_VERSION"), + azure_endpoint=os.getenv("AZURE_API_BASE"), + http_client=httpx.Client(verify=False) if not args.verify_ssl else None, + ) elif args.model == "deepseek-coder-v2-0724": import openai diff --git a/review_iclr_bench/iclr_analysis.py b/review_iclr_bench/iclr_analysis.py index 40eb031e..d57ae06d 100644 --- a/review_iclr_bench/iclr_analysis.py +++ b/review_iclr_bench/iclr_analysis.py @@ -37,10 +37,30 @@ def parse_arguments(): "llama-3-1-405b-instruct", "deepseek-coder-v2-0724", "claude-3-5-sonnet-20240620", + # Anthropic Claude models via Amazon Bedrock + "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/anthropic.claude-3-opus-20240229-v1:0", + # Anthropic Claude models Vertex AI + "vertex_ai/claude-3-opus@20240229", + "vertex_ai/claude-3-5-sonnet@20240620", + "vertex_ai/claude-3-sonnet@20240229", + "vertex_ai/claude-3-haiku@20240307", + # OpenAI models via Azure + "azure/gpt-4o-2024-08-06", + 
"azure/gpt-4o-2024-05-13", ], help="Model to use for AI Scientist.", ) + parser.add_argument( + "--verify-ssl", # implemented only for Azure OpenAI API + action=argparse.BooleanOptionalAction, + default=True, + help="Verify SSL certificate when connecting to model API (default: True)", + ) + parser.add_argument( "--num_reviews", type=int, @@ -235,6 +255,7 @@ def review_single_paper( reviewer_system_prompt, review_instruction_form, num_paper_pages, + verify_ssl, ): # Setup client for LLM model if model == "claude-3-5-sonnet-20240620": @@ -266,6 +287,20 @@ def review_single_paper( client = openai.OpenAI( api_key=os.environ["DEEPSEEK_API_KEY"], base_url="https://api.deepseek.com" ) + elif model.startswith("azure") and "gpt" in model: + import openai + if not verify_ssl: import httpx + + # Expects: azure/<deployment-name>; reads the `model` parameter (not the global `args`) so imported/worker-process calls work + model = model.split("/")[-1] + + print(f"Using Azure with model {model}.") + client = openai.AzureOpenAI( + api_key=os.getenv("AZURE_API_KEY"), + api_version=os.getenv("AZURE_API_VERSION"), + azure_endpoint=os.getenv("AZURE_API_BASE"), + http_client=httpx.Client(verify=False) if not verify_ssl else None, + ) elif model == "llama-3-1-405b-instruct": import openai @@ -339,6 +374,7 @@ def open_review_validate( num_paper_pages=None, data_seed=1, balanced_val=False, + verify_ssl=True, ): ore_ratings = prep_open_review_data( data_seed=data_seed, @@ -387,6 +423,7 @@ def open_review_validate( reviewer_system_prompt, review_instruction_form, num_paper_pages, + verify_ssl, ] ) for _ in range(batch_size): @@ -449,7 +486,7 @@ def open_review_validate( args = parse_arguments() # Create client - float temp as string temperature = str(args.temperature).replace(".", "_") - rating_fname = f"llm_reviews/{args.model}_temp_{temperature}" + rating_fname = f"llm_reviews/{args.model.replace('/','_')}_temp_{temperature}" pathlib.Path("llm_reviews/").mkdir(parents=True, exist_ok=True) if args.num_fs_examples > 0: @@ -485,4 +522,5 @@ def open_review_validate( 
reviewer_form_prompt, num_paper_pages, balanced_val=False, + verify_ssl=args.verify_ssl, )