Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions sdk-go/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ func GetModel(provider, modelID string) llmsdk.LanguageModel {
}
return google.NewGoogleModel(modelID, google.GoogleModelOptions{
APIKey: apiKey,
// ProviderType: google.ProviderTypeVertexAI,
// AccessToken: "your-access-token",
// Location: "us-central1",
// ProjectID: "your-project-id",
})
default:
panic(fmt.Sprintf("Unsupported provider: %s", provider))
Expand Down
8 changes: 7 additions & 1 deletion sdk-go/accumulator.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ func mergeDelta(existing accumulatedData, delta ContentDelta) error {
if toolCallPartDelta.ID != nil {
existingData.ID = toolCallPartDelta.ID
}
if toolCallPartDelta.ThoughtSignature != nil {
existingData.ThoughtSignature = toolCallPartDelta.ThoughtSignature
}
case existing.Image != nil:
imagePartDelta := delta.Part.ImagePartDelta
if imagePartDelta == nil {
Expand Down Expand Up @@ -294,6 +297,9 @@ func createToolCallPart(data *ToolCallPartDelta, index int) (Part, error) {
if data.ID != nil {
opts = append(opts, WithToolCallPartID(*data.ID))
}
if data.ThoughtSignature != nil {
opts = append(opts, WithToolCallThoughtSignature(*data.ThoughtSignature))
}

toolCallPart := NewToolCallPart(*data.ToolCallID, *data.ToolName, args, opts...)
return toolCallPart, nil
Expand Down Expand Up @@ -353,7 +359,7 @@ func createAudioPart(data *accumulatedAudioData) (Part, error) {

// createReasoningPart creates a reasoning part from accumulated reasoning data
func createReasoningPart(data *ReasoningPartDelta) Part {
var opts []ReasoingPartOption
var opts []ReasoningPartOption
if data.Signature != nil {
opts = append(opts, WithReasoningSignature(*data.Signature))
}
Expand Down
2 changes: 1 addition & 1 deletion sdk-go/anthropic/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ func mapAnthropicContentBlock(block anthropicapi.ContentBlock) (*llmsdk.Part, er
return &part, nil

case block.ResponseThinkingBlock != nil:
opts := []llmsdk.ReasoingPartOption{}
opts := []llmsdk.ReasoningPartOption{}
if block.ResponseThinkingBlock.Signature != "" {
opts = append(opts, llmsdk.WithReasoningSignature(block.ResponseThinkingBlock.Signature))
}
Expand Down
139 changes: 100 additions & 39 deletions sdk-go/google/google.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package google

import (
"cmp"
"context"
"encoding/json"
"fmt"
"maps"
"net/http"
"strings"

Expand All @@ -20,51 +22,75 @@ import (

const Provider = "google"

type ProviderType string

const (
ProviderTypeGoogleAI ProviderType = "google-ai"
ProviderTypeVertexAI ProviderType = "vertex-ai"
)

type GoogleModelOptions struct {
BaseURL string
APIKey string
APIVersion string
Headers map[string]string
HTTPClient *http.Client
BaseURL string
APIKey string
APIVersion string
Headers map[string]string
HTTPClient *http.Client
ProviderType ProviderType
AccessToken string
ProjectID string
Location string
}

type GoogleModel struct {
baseURL string
apiKey string
apiVersion string
modelID string
client *http.Client
metadata *llmsdk.LanguageModelMetadata
headers map[string]string
baseURL string
apiKey string
apiVersion string
modelID string
client *http.Client
metadata *llmsdk.LanguageModelMetadata
headers map[string]string
providerType ProviderType
accessToken string
projectID string
location string
}

func NewGoogleModel(modelID string, options GoogleModelOptions) *GoogleModel {
baseURL := "https://generativelanguage.googleapis.com"
if options.BaseURL != "" {
baseURL = options.BaseURL
}
apiVersion := "v1beta"
if options.APIVersion != "" {
apiVersion = options.APIVersion
}

client := options.HTTPClient
if client == nil {
client = &http.Client{}
providerType := cmp.Or(options.ProviderType, ProviderTypeGoogleAI)

baseURL := options.BaseURL
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
if providerType == ProviderTypeVertexAI {
if options.Location == "" || options.Location == "global" {
baseURL = "https://aiplatform.googleapis.com"
} else {
baseURL = "https://" + options.Location + "-aiplatform.googleapis.com"
}
}
}

headers := map[string]string{}
for k, v := range options.Headers {
headers[k] = v
apiVersion := options.APIVersion
if apiVersion == "" {
if providerType == ProviderTypeVertexAI {
apiVersion = "v1beta1"
} else {
apiVersion = "v1beta"
}
}

return &GoogleModel{
baseURL: baseURL,
apiKey: options.APIKey,
apiVersion: apiVersion,
modelID: modelID,
client: client,
headers: headers,
client: cmp.Or(options.HTTPClient, &http.Client{}),
headers: maps.Clone(options.Headers),

providerType: providerType,
accessToken: options.AccessToken,
projectID: options.ProjectID,
location: options.Location,
}
}

Expand All @@ -86,17 +112,44 @@ func (m *GoogleModel) Metadata() *llmsdk.LanguageModelMetadata {
}

func (m *GoogleModel) requestHeaders() map[string]string {
headers := map[string]string{
"x-goog-api-key": m.apiKey,
headers := maps.Clone(m.headers)
if headers == nil {
headers = make(map[string]string, 1)
}

for k, v := range m.headers {
headers[k] = v
if m.accessToken != "" {
headers["Authorization"] = "Bearer " + m.accessToken
} else if m.apiKey != "" {
headers["x-goog-api-key"] = m.apiKey
}

return headers
}

func (m *GoogleModel) buildEndpointURL(action string) string {
var b strings.Builder
b.WriteString(m.baseURL)
b.WriteByte('/')
b.WriteString(m.apiVersion)

if m.providerType == ProviderTypeVertexAI {
if m.projectID != "" {
b.WriteString("/projects/")
b.WriteString(m.projectID)
b.WriteString("/locations/")
b.WriteString(cmp.Or(m.location, "global"))
}
b.WriteString("/publishers/google")
}

b.WriteString("/models/")
b.WriteString(m.modelID)
b.WriteByte(':')
b.WriteString(action)

return b.String()
}

func (m *GoogleModel) Generate(ctx context.Context, input *llmsdk.LanguageModelInput) (*llmsdk.ModelResponse, error) {
return tracing.TraceGenerate(ctx, string(Provider), m.modelID, input, func(ctx context.Context) (*llmsdk.ModelResponse, error) {
params, err := convertToGenerateContentParameters(input, m.modelID)
Expand All @@ -105,7 +158,7 @@ func (m *GoogleModel) Generate(ctx context.Context, input *llmsdk.LanguageModelI
}

response, err := clientutils.DoJSON[googleapi.GenerateContentResponse](ctx, m.client, clientutils.JSONRequestConfig{
URL: fmt.Sprintf("%s/%s/models/%s:generateContent", m.baseURL, m.apiVersion, m.modelID),
URL: m.buildEndpointURL("generateContent"),
Headers: m.requestHeaders(),
Body: params,
})
Expand Down Expand Up @@ -149,7 +202,7 @@ func (m *GoogleModel) Stream(ctx context.Context, input *llmsdk.LanguageModelInp
}

sseStream, err := clientutils.DoSSE[googleapi.GenerateContentResponse](ctx, m.client, clientutils.SSERequestConfig{
URL: fmt.Sprintf("%s/%s/models/%s:streamGenerateContent?alt=sse", m.baseURL, m.apiVersion, m.modelID),
URL: m.buildEndpointURL("streamGenerateContent?alt=sse"),
Headers: m.requestHeaders(),
Body: params,
})
Expand Down Expand Up @@ -345,13 +398,17 @@ func convertToGoogleParts(part llmsdk.Part) ([]googleapi.Part, error) {
parts,
), nil
case part.ToolCallPart != nil:
return []googleapi.Part{{
googlePart := googleapi.Part{
FunctionCall: &googleapi.FunctionCall{
Name: &part.ToolCallPart.ToolName,
Args: part.ToolCallPart.Args,
Id: &part.ToolCallPart.ToolCallID,
},
}}, nil
}
if part.ToolCallPart.ThoughtSignature != nil {
googlePart.ThoughtSignature = part.ToolCallPart.ThoughtSignature
}
return []googleapi.Part{googlePart}, nil
case part.ToolResultPart != nil:
response, err := convertToGoogleFunctionResponseResponse(part.ToolResultPart.Content, part.ToolResultPart.IsError)
if err != nil {
Expand Down Expand Up @@ -493,7 +550,7 @@ func mapGoogleContent(parts []googleapi.Part) ([]llmsdk.Part, error) {
if part.Text != nil {
text = *part.Text
}
opts := []llmsdk.ReasoingPartOption{}
opts := []llmsdk.ReasoningPartOption{}
if part.ThoughtSignature != nil {
opts = append(opts, llmsdk.WithReasoningSignature(*part.ThoughtSignature))
}
Expand Down Expand Up @@ -534,7 +591,11 @@ func mapGoogleContent(parts []googleapi.Part) ([]llmsdk.Part, error) {
} else {
toolCallID = fmt.Sprintf("call_%s", randutil.String(10))
}
toolCallPart := llmsdk.NewToolCallPart(toolCallID, *part.FunctionCall.Name, json.RawMessage(part.FunctionCall.Args))
opts := []llmsdk.ToolCallPartOption{}
if part.ThoughtSignature != nil {
opts = append(opts, llmsdk.WithToolCallThoughtSignature(*part.ThoughtSignature))
}
toolCallPart := llmsdk.NewToolCallPart(toolCallID, *part.FunctionCall.Name, json.RawMessage(part.FunctionCall.Args), opts...)
toolCallPart.ToolCallPart.Args = part.FunctionCall.Args
result = append(result, toolCallPart)
continue
Expand Down
Loading