From f58a34788cf192087db0d87f8eea2f6d2eaa3559 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 22:00:59 -0700 Subject: [PATCH] independent test cases --- core/providers/anthropic_test.go | 248 ++++++++++++++++ core/providers/azure_test.go | 469 +++++++++++++++++++++++++++++ core/providers/bedrock_test.go | 391 +++++++++++++++++++++++++ core/providers/cohere_test.go | 336 +++++++++++++++++++++ core/providers/gemini_test.go | 244 ++++++++++++++++ core/providers/groq_test.go | 257 ++++++++++++++++ core/providers/openai_test.go | 285 ++++++++++++++++++ core/providers/openrouter_test.go | 202 +++++++++++++ core/providers/vertex_test.go | 470 ++++++++++++++++++++++++++++++ 9 files changed, 2902 insertions(+) create mode 100644 core/providers/anthropic_test.go create mode 100644 core/providers/azure_test.go create mode 100644 core/providers/bedrock_test.go create mode 100644 core/providers/cohere_test.go create mode 100644 core/providers/gemini_test.go create mode 100644 core/providers/groq_test.go create mode 100644 core/providers/openai_test.go create mode 100644 core/providers/openrouter_test.go create mode 100644 core/providers/vertex_test.go diff --git a/core/providers/anthropic_test.go b/core/providers/anthropic_test.go new file mode 100644 index 000000000..d19289db1 --- /dev/null +++ b/core/providers/anthropic_test.go @@ -0,0 +1,248 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAnthropicChatCompletion(t *testing.T) { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + + provider := NewAnthropicProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.anthropic.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-3-haiku-20240307", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, err := provider.ChatCompletion(ctx, key, request) + if err != nil { + t.Fatalf("ChatCompletion failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestAnthropicChatCompletionWithTools(t *testing.T) { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + + provider := NewAnthropicProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.anthropic.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-3-haiku-20240307", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, err := provider.ChatCompletion(ctx, key, request) + if err != nil { + t.Fatalf("ChatCompletion with tools failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestAnthropicChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + + provider := NewAnthropicProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.anthropic.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Anthropic, + Model: "claude-3-haiku-20240307", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, err := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if err != nil { + t.Fatalf("ChatCompletionStream failed: %v", err) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestAnthropicResponses(t *testing.T) { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + if apiKey == "" { + t.Skip("ANTHROPIC_API_KEY not set") + } + + provider := NewAnthropicProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.anthropic.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + userRole := schemas.ResponsesInputMessageRoleUser + request := &schemas.BifrostResponsesRequest{ + Provider: schemas.Anthropic, + Model: "claude-3-5-sonnet-20241022", + Input: []schemas.ResponsesMessage{ + { + Role: &userRole, + Content: &schemas.ResponsesMessageContent{ContentStr: stringPtr("What is 2+2?")}, + }, + }, + Params: &schemas.ResponsesParameters{ + MaxOutputTokens: intPtr(100), + }, + } + + resp, err := provider.Responses(ctx, key, request) + if err != nil { + t.Fatalf("Responses failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Output) == 0 { + t.Fatal("Expected at least one output message") + } + t.Logf("Response output messages: %d", len(resp.Output)) +} + +func TestAnthropicGetProviderKey(t *testing.T) { + provider := NewAnthropicProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.anthropic.com", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + key := provider.GetProviderKey() + if key != schemas.Anthropic { + t.Errorf("Expected provider key %s, got %s", schemas.Anthropic, key) + } +} diff --git a/core/providers/azure_test.go b/core/providers/azure_test.go new file mode 100644 index 000000000..3ae19c01a --- /dev/null +++ b/core/providers/azure_test.go @@ -0,0 +1,469 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestAzureChatCompletion(t *testing.T) { + apiKey := os.Getenv("AZURE_API_KEY") + endpoint := os.Getenv("AZURE_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_API_KEY or AZURE_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "gpt-4o-mini": "gpt-4o-mini", + }, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Azure, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestAzureChatCompletionWithSingleTool(t *testing.T) { + apiKey := os.Getenv("AZURE_API_KEY") + endpoint := os.Getenv("AZURE_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_API_KEY or AZURE_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "gpt-4o-mini": "gpt-4o-mini", + }, + }, + } + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Azure, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with single tool failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestAzureChatCompletionWithMultipleTools(t *testing.T) { + apiKey := os.Getenv("AZURE_API_KEY") + endpoint := os.Getenv("AZURE_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_API_KEY or AZURE_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "gpt-4o-mini": "gpt-4o-mini", + }, + }, + } + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + timeProps := map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Azure, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather and current time in New York?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_time", + Description: stringPtr("Get the current time"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &timeProps, + Required: []string{"timezone"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with multiple tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestAzureChatCompletionWithParallelToolCalls(t *testing.T) { + apiKey := os.Getenv("AZURE_API_KEY") + endpoint := os.Getenv("AZURE_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_API_KEY or AZURE_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "gpt-4o-mini": "gpt-4o-mini", + }, + }, + } + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Azure, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco, New York, and London?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + ParallelToolCalls: boolPtr(true), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with parallel tool calls failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Parallel tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestAzureChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("AZURE_API_KEY") + endpoint := os.Getenv("AZURE_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_API_KEY or AZURE_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "gpt-4o-mini": "gpt-4o-mini", + }, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Azure, + Model: "gpt-4o-mini", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 5")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestAzureEmbedding(t *testing.T) { + apiKey := os.Getenv("AZURE_EMB_API_KEY") + endpoint := os.Getenv("AZURE_EMB_ENDPOINT") + if apiKey == "" || endpoint == "" { + t.Skip("AZURE_EMB_API_KEY or AZURE_EMB_ENDPOINT not set") + } + + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + Value: apiKey, + AzureKeyConfig: &schemas.AzureKeyConfig{ + Endpoint: endpoint, + Deployments: map[string]string{ + "text-embedding-ada-002": "text-embedding-ada-002", + }, + }, + } + + request := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.Azure, + Model: "text-embedding-ada-002", + Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}}, + } + + resp, bfErr := provider.Embedding(ctx, key, request) + if bfErr != nil { + t.Fatalf("Embedding failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Data) == 0 { + t.Fatal("Expected at least one embedding") + } + if len(resp.Data[0].Embedding.EmbeddingArray) == 0 { + t.Fatal("Expected non-empty embedding vector") + } + t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray)) +} + +func TestAzureGetProviderKey(t *testing.T) { + provider, err := NewAzureProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Azure provider: %v", err) + } + + key := provider.GetProviderKey() + if key != schemas.Azure { + t.Errorf("Expected provider key %s, got %s", schemas.Azure, key) + } +} diff --git a/core/providers/bedrock_test.go b/core/providers/bedrock_test.go new file mode 100644 index 000000000..8c84659ef --- /dev/null +++ b/core/providers/bedrock_test.go @@ -0,0 +1,391 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestBedrockChatCompletion(t *testing.T) { + apiKey := os.Getenv("BEDROCK_API_KEY") + if apiKey == "" { + t.Fatal("BEDROCK_API_KEY not set") + } + + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + ctx := context.Background() + region := "us-east-1" + key := schemas.Key{ + Value: apiKey, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + Region: ®ion, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-haiku-20240307-v1:0", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + if bfErr.Error != nil { + if bfErr.Error.Error != nil { + t.Fatalf("ChatCompletion failed - error: %v", bfErr.Error.Error) + } + t.Fatalf("ChatCompletion failed - message: %s", bfErr.Error.Message) + } + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestBedrockChatCompletionWithSingleTool(t *testing.T) { + apiKey := os.Getenv("BEDROCK_API_KEY") + if apiKey == "" { + t.Fatal("BEDROCK_API_KEY not set") + } + + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + ctx := context.Background() + region := "us-east-1" + key := schemas.Key{ + Value: apiKey, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + Region: ®ion, + }, + } + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-haiku-20240307-v1:0", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with single tool failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestBedrockChatCompletionWithMultipleTools(t *testing.T) { + apiKey := os.Getenv("BEDROCK_API_KEY") + if apiKey == "" { + t.Fatal("BEDROCK_API_KEY not set") + } + + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + ctx := context.Background() + region := "us-east-1" + key := schemas.Key{ + Value: apiKey, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + Region: ®ion, + }, + } + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + timeProps := map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-haiku-20240307-v1:0", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather and current time in New York?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_time", + Description: stringPtr("Get the current time"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &timeProps, + Required: []string{"timezone"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with multiple tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestBedrockChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("BEDROCK_API_KEY") + if apiKey == "" { + t.Fatal("BEDROCK_API_KEY not set") + } + + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + ctx := context.Background() + region := "us-east-1" + key := schemas.Key{ + Value: apiKey, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + Region: ®ion, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-haiku-20240307-v1:0", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 5")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestBedrockResponses(t *testing.T) { + apiKey := os.Getenv("BEDROCK_API_KEY") + if apiKey == "" { + t.Fatal("BEDROCK_API_KEY not set") + } + + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + ctx := context.Background() + region := "us-east-1" + key := schemas.Key{ + Value: apiKey, + BedrockKeyConfig: &schemas.BedrockKeyConfig{ + Region: ®ion, + }, + } + + userRole := schemas.ResponsesInputMessageRoleUser + request := &schemas.BifrostResponsesRequest{ + Provider: schemas.Bedrock, + Model: "anthropic.claude-3-haiku-20240307-v1:0", + Input: []schemas.ResponsesMessage{ + { + Role: &userRole, + Content: &schemas.ResponsesMessageContent{ContentStr: stringPtr("What is 2+2?")}, + }, + }, + Params: &schemas.ResponsesParameters{ + MaxOutputTokens: intPtr(100), + }, + } + + resp, bfErr := provider.Responses(ctx, key, request) + if bfErr != nil { + if bfErr.Error != nil { + if bfErr.Error.Error != nil { + t.Fatalf("Responses failed - error: %v", bfErr.Error.Error) + } + t.Fatalf("Responses failed - message: %s", bfErr.Error.Message) + } + t.Fatalf("Responses failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Output) == 0 { + t.Fatal("Expected at least one output message") + } + t.Logf("Response output messages: %d", len(resp.Output)) +} + +func TestBedrockGetProviderKey(t *testing.T) { + provider, err := NewBedrockProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Bedrock provider: %v", err) + } + + key := provider.GetProviderKey() + if key != schemas.Bedrock { + t.Errorf("Expected provider key %s, got %s", schemas.Bedrock, key) + } +} diff --git a/core/providers/cohere_test.go b/core/providers/cohere_test.go new file mode 100644 index 000000000..b5cffa3c8 --- /dev/null +++ b/core/providers/cohere_test.go @@ -0,0 +1,336 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestCohereChatCompletion(t *testing.T) { + apiKey := os.Getenv("COHERE_API_KEY") + if apiKey == "" { + t.Skip("COHERE_API_KEY not set") + } + + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Cohere, + Model: "command-r-08-2024", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + if bfErr.Error != nil { + if bfErr.Error.Error != nil { + t.Fatalf("ChatCompletion failed - error: %v", bfErr.Error.Error) + } + t.Fatalf("ChatCompletion failed - message: %s", bfErr.Error.Message) + } + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestCohereChatCompletionWithSingleTool(t *testing.T) { + apiKey := os.Getenv("COHERE_API_KEY") + if apiKey == "" { + t.Skip("COHERE_API_KEY not set") + } + + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Cohere, + Model: "command-r-08-2024", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with single tool failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestCohereChatCompletionWithMultipleTools(t *testing.T) { + apiKey := os.Getenv("COHERE_API_KEY") + if apiKey == "" { + t.Skip("COHERE_API_KEY not set") + } + + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + timeProps := map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Cohere, + Model: "command-r-08-2024", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather and current time in New York?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_time", + Description: stringPtr("Get the current time"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &timeProps, + Required: []string{"timezone"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with multiple tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestCohereChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("COHERE_API_KEY") + if apiKey == "" { + t.Skip("COHERE_API_KEY not set") + } + + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Cohere, + Model: "command-r-08-2024", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 5")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestCohereEmbedding(t *testing.T) { + apiKey := os.Getenv("COHERE_API_KEY") + if apiKey == "" { + t.Skip("COHERE_API_KEY not set") + } + + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.Cohere, + Model: "embed-english-v3.0", + Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}}, + } + + resp, bfErr := provider.Embedding(ctx, key, request) + if bfErr != nil { + t.Fatalf("Embedding failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Data) == 0 { + t.Fatal("Expected at least one embedding") + } + if len(resp.Data[0].Embedding.EmbeddingArray) == 0 { + t.Fatal("Expected non-empty embedding vector") + } + t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray)) +} + +func TestCohereGetProviderKey(t *testing.T) { + provider := NewCohereProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.cohere.com", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + key := provider.GetProviderKey() + if key != schemas.Cohere { + t.Errorf("Expected provider key %s, got %s", schemas.Cohere, key) + } +} diff --git a/core/providers/gemini_test.go b/core/providers/gemini_test.go new file mode 100644 index 000000000..65b733a37 --- /dev/null +++ b/core/providers/gemini_test.go @@ -0,0 +1,244 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestGeminiChatCompletion(t *testing.T) { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + t.Fatal("GEMINI_API_KEY not set") + } + + provider := NewGeminiProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Gemini, + Model: "gemini-pro", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + if bfErr.Error != nil { + if bfErr.Error.Error != nil { + t.Fatalf("ChatCompletion failed - error: %v", bfErr.Error.Error) + } + t.Fatalf("ChatCompletion failed - message: %s", bfErr.Error.Message) + } + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestGeminiChatCompletionWithTools(t *testing.T) { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + t.Fatal("GEMINI_API_KEY not set") + } + + provider := NewGeminiProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Gemini, + Model: "gemini-pro", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestGeminiChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + t.Fatal("GEMINI_API_KEY not set") + } + + provider := NewGeminiProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Gemini, + Model: "gemini-pro", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestGeminiEmbedding(t *testing.T) { + apiKey := os.Getenv("GEMINI_API_KEY") + if apiKey == "" { + t.Fatal("GEMINI_API_KEY not set") + } + + provider := NewGeminiProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.Gemini, + Model: "text-embedding-004", + Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}}, + } + + resp, bfErr := provider.Embedding(ctx, key, request) + if bfErr != nil { + t.Fatalf("Embedding failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Data) == 0 { + t.Fatal("Expected at least one embedding") + } + if len(resp.Data[0].Embedding.EmbeddingArray) == 0 { + t.Fatal("Expected non-empty embedding vector") + } + t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray)) +} + +func TestGeminiGetProviderKey(t *testing.T) { + provider := NewGeminiProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://generativelanguage.googleapis.com", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + key := provider.GetProviderKey() + if key != schemas.Gemini { + t.Errorf("Expected provider key %s, got %s", schemas.Gemini, key) + } +} diff --git a/core/providers/groq_test.go b/core/providers/groq_test.go new file mode 100644 index 000000000..ad621923e --- /dev/null +++ b/core/providers/groq_test.go @@ -0,0 +1,257 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestGroqChatCompletion(t *testing.T) { + apiKey := os.Getenv("GROQ_API_KEY") + if apiKey == "" { + t.Skip("GROQ_API_KEY not set") + } + + provider, err := NewGroqProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.groq.com/openai", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Groq provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Groq, + Model: "llama-3.3-70b-versatile", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestGroqChatCompletionWithTools(t *testing.T) { + apiKey := os.Getenv("GROQ_API_KEY") + if apiKey == "" { + t.Skip("GROQ_API_KEY not set") + } + + provider, err := NewGroqProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.groq.com/openai", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Groq provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Groq, + Model: "llama-3.3-70b-versatile", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestGroqChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("GROQ_API_KEY") + if apiKey == "" { + t.Skip("GROQ_API_KEY not set") + } + + provider, err := NewGroqProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.groq.com/openai", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Groq provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Groq, + Model: "llama-3.3-70b-versatile", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestGroqTranscription(t *testing.T) { + apiKey := os.Getenv("GROQ_API_KEY") + if apiKey == "" { + t.Skip("GROQ_API_KEY not set") + } + + provider, err := NewGroqProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.groq.com/openai", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Groq provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + // Note: This test requires an actual audio file. Skipping for now. + // In a real test, you would provide a base64-encoded audio file. + t.Skip("Transcription test requires audio file - implement when needed") + + request := &schemas.BifrostTranscriptionRequest{ + Provider: schemas.Groq, + Model: "whisper-large-v3", + Input: &schemas.TranscriptionInput{ + // Would need actual audio data here + }, + } + + resp, bfErr := provider.Transcription(ctx, key, request) + if bfErr != nil { + t.Fatalf("Transcription failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + t.Logf("Transcription: %v", resp.Text) +} + +func TestGroqGetProviderKey(t *testing.T) { + provider, err := NewGroqProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.groq.com/openai", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Groq provider: %v", err) + } + + key := provider.GetProviderKey() + if key != schemas.Groq { + t.Errorf("Expected provider key %s, got %s", schemas.Groq, key) + } +} diff --git a/core/providers/openai_test.go b/core/providers/openai_test.go new file mode 100644 index 000000000..cdfc080ee --- /dev/null +++ b/core/providers/openai_test.go @@ -0,0 +1,285 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenAIChatCompletion(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-3.5-turbo", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, err := provider.ChatCompletion(ctx, key, request) + if err != nil { + t.Fatalf("ChatCompletion failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestOpenAIChatCompletionWithTools(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-3.5-turbo", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, err := provider.ChatCompletion(ctx, key, request) + if err != nil { + t.Fatalf("ChatCompletion with tools failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestOpenAIChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenAI, + Model: "gpt-3.5-turbo", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, err := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if err != nil { + t.Fatalf("ChatCompletionStream failed: %v", err) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestOpenAITextCompletion(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostTextCompletionRequest{ + Provider: schemas.OpenAI, + Model: "gpt-3.5-turbo-instruct", + Input: &schemas.TextCompletionInput{PromptStr: stringPtr("Say hello")}, + Params: &schemas.TextCompletionParameters{ + Temperature: float64Ptr(0.7), + MaxTokens: intPtr(10), + }, + } + + resp, err := provider.TextCompletion(ctx, key, request) + if err != nil { + t.Fatalf("TextCompletion failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Response: %v", resp.Choices[0].Text) +} + +func TestOpenAIEmbedding(t *testing.T) { + apiKey := os.Getenv("OPENAI_API_KEY") + if apiKey == "" { + t.Skip("OPENAI_API_KEY not set") + } + + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.OpenAI, + Model: "text-embedding-ada-002", + Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}}, + } + + resp, err := provider.Embedding(ctx, key, request) + if err != nil { + t.Fatalf("Embedding failed: %v", err) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Data) == 0 { + t.Fatal("Expected at least one embedding") + } + if len(resp.Data[0].Embedding.EmbeddingArray) == 0 { + t.Fatal("Expected non-empty embedding vector") + } + t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray)) +} + +func TestOpenAIGetProviderKey(t *testing.T) { + provider := NewOpenAIProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://api.openai.com", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + key := provider.GetProviderKey() + if key != schemas.OpenAI { + t.Errorf("Expected provider key %s, got %s", schemas.OpenAI, key) + } +} diff --git a/core/providers/openrouter_test.go b/core/providers/openrouter_test.go new file mode 100644 index 000000000..8d69c5976 --- /dev/null +++ b/core/providers/openrouter_test.go @@ -0,0 +1,202 @@ +package providers + +import ( + "context" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestOpenRouterChatCompletion(t *testing.T) { + apiKey := os.Getenv("OPENROUTER_API_KEY") + if apiKey == "" { + t.Skip("OPENROUTER_API_KEY not set") + } + + provider := NewOpenRouterProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://openrouter.ai/api", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenRouter, + Model: "meta-llama/llama-3.2-3b-instruct:free", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestOpenRouterChatCompletionWithTools(t *testing.T) { + t.Skip("Free model doesn't support tools") + + apiKey := os.Getenv("OPENROUTER_API_KEY") + if apiKey == "" { + t.Skip("OPENROUTER_API_KEY not set") + } + + provider := NewOpenRouterProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://openrouter.ai/api", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenRouter, + Model: "meta-llama/llama-3.2-3b-instruct:free", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestOpenRouterChatCompletionStream(t *testing.T) { + apiKey := os.Getenv("OPENROUTER_API_KEY") + if apiKey == "" { + t.Skip("OPENROUTER_API_KEY not set") + } + + provider := NewOpenRouterProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://openrouter.ai/api", + DefaultRequestTimeoutInSeconds: 30, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + ctx := context.Background() + key := schemas.Key{Value: apiKey} + + request := &schemas.BifrostChatRequest{ + Provider: schemas.OpenRouter, + Model: "meta-llama/llama-3.2-3b-instruct:free", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 3")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestOpenRouterGetProviderKey(t *testing.T) { + provider := NewOpenRouterProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + BaseURL: "https://openrouter.ai/api", + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + + key := provider.GetProviderKey() + if key != schemas.OpenRouter { + t.Errorf("Expected provider key %s, got %s", schemas.OpenRouter, key) + } +} diff --git a/core/providers/vertex_test.go b/core/providers/vertex_test.go new file mode 100644 index 000000000..2531e3083 --- /dev/null +++ b/core/providers/vertex_test.go @@ -0,0 +1,470 @@ +package providers + +import ( + "context" + "fmt" + "os" + "testing" + + "github.com/maximhq/bifrost/core/schemas" +) + +func TestVertexChatCompletion(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Vertex, + Model: "claude-3-5-haiku", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Say hello in one word")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(10), + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + if bfErr.Error != nil { + errorMsg := bfErr.Error.Message + if bfErr.Error.Error != nil { + errorMsg += fmt.Sprintf(" | Details: %v", bfErr.Error.Error) + } + t.Fatalf("ChatCompletion failed: %s", errorMsg) + } + t.Fatalf("ChatCompletion failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + if resp.Choices[0].Message.Content == nil { + t.Fatal("Expected message content") + } + t.Logf("Response: %v", resp.Choices[0].Message.Content) +} + +func TestVertexChatCompletionWithSingleTool(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + props := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Vertex, + Model: "claude-3-5-haiku", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(100), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &props, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with single tool failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestVertexChatCompletionWithMultipleTools(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + timeProps := map[string]interface{}{ + "timezone": map[string]interface{}{ + "type": "string", + "description": "The timezone identifier", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Vertex, + Model: "claude-3-7-sonnet@20240229", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather and current time in New York?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_time", + Description: stringPtr("Get the current time"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &timeProps, + Required: []string{"timezone"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with multiple tools failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestVertexChatCompletionWithParallelToolCalls(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + weatherProps := map[string]interface{}{ + "location": map[string]interface{}{ + "type": "string", + "description": "The city name", + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Vertex, + Model: "claude-3-5-haiku", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("What's the weather in San Francisco, New York, and London?")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + MaxCompletionTokens: intPtr(200), + ParallelToolCalls: boolPtr(true), + Tools: []schemas.ChatTool{ + { + Type: schemas.ChatToolTypeFunction, + Function: &schemas.ChatToolFunction{ + Name: "get_weather", + Description: stringPtr("Get the current weather"), + Parameters: &schemas.ToolFunctionParameters{ + Type: "object", + Properties: &weatherProps, + Required: []string{"location"}, + }, + }, + }, + }, + }, + } + + resp, bfErr := provider.ChatCompletion(ctx, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletion with parallel tool calls failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Choices) == 0 { + t.Fatal("Expected at least one choice") + } + t.Logf("Parallel tool calls: %d", len(resp.Choices[0].Message.ToolCalls)) +} + +func TestVertexChatCompletionStream(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + request := &schemas.BifrostChatRequest{ + Provider: schemas.Vertex, + Model: "claude-3-5-haiku", + Input: []schemas.ChatMessage{ + { + Role: "user", + Content: &schemas.ChatMessageContent{ContentStr: stringPtr("Count from 1 to 5")}, + }, + }, + Params: &schemas.ChatParameters{ + Temperature: float64Ptr(0.7), + }, + } + + streamChan, bfErr := provider.ChatCompletionStream(ctx, mockPostHookRunner, key, request) + if bfErr != nil { + t.Fatalf("ChatCompletionStream failed: %v", bfErr) + } + + count := 0 + for chunk := range streamChan { + if chunk == nil { + continue + } + if chunk.BifrostError != nil { + t.Fatalf("Stream error: %v", chunk.BifrostError) + } + count++ + } + + if count == 0 { + t.Fatal("Expected at least one chunk") + } + t.Logf("Received %d chunks", count) +} + +func TestVertexEmbedding(t *testing.T) { + credentials := os.Getenv("VERTEX_CREDENTIALS") + projectID := os.Getenv("VERTEX_PROJECT_ID") + if credentials == "" || projectID == "" { + t.Fatal("VERTEX_CREDENTIALS or VERTEX_PROJECT_ID not set") + } + + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + ctx := context.Background() + key := schemas.Key{ + VertexKeyConfig: &schemas.VertexKeyConfig{ + ProjectID: projectID, + Region: "us-central1", + AuthCredentials: credentials, + }, + } + + request := &schemas.BifrostEmbeddingRequest{ + Provider: schemas.Vertex, + Model: "text-embedding-004", + Input: &schemas.EmbeddingInput{Texts: []string{"Hello world"}}, + } + + resp, bfErr := provider.Embedding(ctx, key, request) + if bfErr != nil { + t.Fatalf("Embedding failed: %v", bfErr) + } + + if resp == nil { + t.Fatal("Expected non-nil response") + } + if len(resp.Data) == 0 { + t.Fatal("Expected at least one embedding") + } + if len(resp.Data[0].Embedding.EmbeddingArray) == 0 { + t.Fatal("Expected non-empty embedding vector") + } + t.Logf("Embedding dimension: %d", len(resp.Data[0].Embedding.EmbeddingArray)) +} + +func TestVertexGetProviderKey(t *testing.T) { + provider, err := NewVertexProvider(&schemas.ProviderConfig{ + NetworkConfig: schemas.NetworkConfig{ + DefaultRequestTimeoutInSeconds: 60, + }, + ConcurrencyAndBufferSize: schemas.ConcurrencyAndBufferSize{ + Concurrency: 10, + }, + }, newTestLogger()) + if err != nil { + t.Fatalf("Failed to create Vertex provider: %v", err) + } + + key := provider.GetProviderKey() + if key != schemas.Vertex { + t.Errorf("Expected provider key %s, got %s", schemas.Vertex, key) + } +} + +// Helper function for bool pointers +func boolPtr(b bool) *bool { + return &b +}