diff --git a/lib/ruby_llm.rb b/lib/ruby_llm.rb index 1817ddc45..9c12e3538 100644 --- a/lib/ruby_llm.rb +++ b/lib/ruby_llm.rb @@ -14,6 +14,7 @@ 'ruby_llm' => 'RubyLLM', 'llm' => 'LLM', 'openai' => 'OpenAI', + 'azure_openai' => 'AzureOpenAI', 'api' => 'API', 'deepseek' => 'DeepSeek', 'bedrock' => 'Bedrock', @@ -85,6 +86,7 @@ def logger RubyLLM::Provider.register :openrouter, RubyLLM::Providers::OpenRouter RubyLLM::Provider.register :ollama, RubyLLM::Providers::Ollama RubyLLM::Provider.register :gpustack, RubyLLM::Providers::GPUStack +RubyLLM::Provider.register :azure_openai, RubyLLM::Providers::AzureOpenAI if defined?(Rails::Railtie) require 'ruby_llm/railtie' diff --git a/lib/ruby_llm/configuration.rb b/lib/ruby_llm/configuration.rb index 3348b121d..0075cd412 100644 --- a/lib/ruby_llm/configuration.rb +++ b/lib/ruby_llm/configuration.rb @@ -26,6 +26,10 @@ class Configuration :ollama_api_base, :gpustack_api_base, :gpustack_api_key, + # Azure OpenAI Provider configuration + :azure_openai_api_base, + :azure_openai_api_version, + :azure_openai_api_key, # Default models :default_model, :default_embedding_model, diff --git a/lib/ruby_llm/providers/azure_openai.rb b/lib/ruby_llm/providers/azure_openai.rb new file mode 100644 index 000000000..9d680e41a --- /dev/null +++ b/lib/ruby_llm/providers/azure_openai.rb @@ -0,0 +1,43 @@ +# frozen_string_literal: true + +module RubyLLM + module Providers + # Azure OpenAI API integration. Derived from OpenAI integration to support + # OpenAI capabilities via Microsoft Azure endpoints. 
+ module AzureOpenAI + extend OpenAI + extend AzureOpenAI::Chat + extend AzureOpenAI::Streaming + extend AzureOpenAI::Models + + module_function + + def api_base(config) + # https:///openai/deployments//chat/completions?api-version= + "#{config.azure_openai_api_base}/openai" + end + + def headers(config) + { + 'Authorization' => "Bearer #{config.azure_openai_api_key}" + }.compact + end + + def capabilities + OpenAI::Capabilities + end + + def slug + 'azure_openai' + end + + def configuration_requirements + %i[azure_openai_api_key azure_openai_api_base azure_openai_api_version] + end + + def local? + false + end + end + end +end diff --git a/lib/ruby_llm/providers/azure_openai/chat.rb b/lib/ruby_llm/providers/azure_openai/chat.rb new file mode 100644 index 000000000..b1e7f515b --- /dev/null +++ b/lib/ruby_llm/providers/azure_openai/chat.rb @@ -0,0 +1,31 @@ +# frozen_string_literal: true +
+module RubyLLM + module Providers + module AzureOpenAI + # Chat methods of the Azure OpenAI API integration + module Chat + extend OpenAI::Chat + + module_function + + def sync_response(connection, payload) + # Hold config in instance variable for use in completion_url and stream_url + @config = connection.config + super + end + + def completion_url + # https:///openai/deployments//chat/completions?api-version= + "deployments/#{@model_id}/chat/completions?api-version=#{@config.azure_openai_api_version}" + end + + def render_payload(messages, tools:, temperature:, model:, stream: false) + # Hold model_id in instance variable for use in completion_url and stream_url + @model_id = model + super + end + end + end + end +end diff --git a/lib/ruby_llm/providers/azure_openai/models.rb b/lib/ruby_llm/providers/azure_openai/models.rb new file mode 100644 index 000000000..18573c170 --- /dev/null +++ b/lib/ruby_llm/providers/azure_openai/models.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true +
+module RubyLLM + module Providers + module AzureOpenAI + # Models methods of the Azure OpenAI API 
integration + module Models + extend OpenAI::Models + + KNOWN_MODELS = [ + 'gpt-4o' + ].freeze + + module_function + + def models_url + 'models?api-version=2024-10-21' + end + + def parse_list_models_response(response, slug, capabilities) + # select the known models only since this list from Azure OpenAI is + # very long + response.body['data'].select! do |m| + KNOWN_MODELS.include?(m['id']) + end + # Use the OpenAI processor for the list, keeping in mind that pricing etc + # won't be correct + super + end + end + end + end +end diff --git a/lib/ruby_llm/providers/azure_openai/streaming.rb b/lib/ruby_llm/providers/azure_openai/streaming.rb new file mode 100644 index 000000000..139ee2578 --- /dev/null +++ b/lib/ruby_llm/providers/azure_openai/streaming.rb @@ -0,0 +1,20 @@ +# frozen_string_literal: true + +module RubyLLM + module Providers + module AzureOpenAI + # Streaming methods of the Azure OpenAI API integration + module Streaming + extend OpenAI::Streaming + + module_function + + def stream_response(connection, payload, &) + # Hold config in instance variable for use in completion_url and stream_url + @config = connection.config + super + end + end + end + end +end diff --git a/lib/tasks/models_update.rake b/lib/tasks/models_update.rake index c9f724c21..5a8797704 100644 --- a/lib/tasks/models_update.rake +++ b/lib/tasks/models_update.rake @@ -24,10 +24,17 @@ def configure_from_env config.deepseek_api_key = ENV.fetch('DEEPSEEK_API_KEY', nil) config.openrouter_api_key = ENV.fetch('OPENROUTER_API_KEY', nil) configure_bedrock(config) + configure_azure_openai(config) config.request_timeout = 30 end end +def configure_azure_openai(config) + config.azure_openai_api_base = ENV.fetch('AZURE_OPENAI_ENDPOINT', nil) + config.azure_openai_api_key = ENV.fetch('AZURE_OPENAI_API_KEY', nil) + config.azure_openai_api_version = ENV.fetch('AZURE_OPENAI_API_VER', nil) +end + def configure_bedrock(config) config.bedrock_api_key = ENV.fetch('AWS_ACCESS_KEY_ID', nil) 
config.bedrock_secret_key = ENV.fetch('AWS_SECRET_ACCESS_KEY', nil)