diff --git a/go/ai/document.go b/go/ai/document.go index 09e80b1f59..ac9dcd3b74 100644 --- a/go/ai/document.go +++ b/go/ai/document.go @@ -107,14 +107,12 @@ func NewCustomPart(customData map[string]any) *Part { } // NewReasoningPart returns a Part containing reasoning text -func NewReasoningPart(text string, signature []byte) *Part { +func NewReasoningPart(text string, metadata map[string]any) *Part { return &Part{ Kind: PartReasoning, ContentType: "plain/text", Text: text, - Metadata: map[string]any{ - "signature": signature, - }, + Metadata: metadata, } } diff --git a/go/go.mod b/go/go.mod index ed667e6b24..96f87fca65 100644 --- a/go/go.mod +++ b/go/go.mod @@ -17,6 +17,7 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/trace v1.27.0 github.com/anthropics/anthropic-sdk-go v1.4.0 github.com/blues/jsonata-go v1.5.4 + github.com/cohesion-org/deepseek-go v1.3.2 github.com/goccy/go-yaml v1.17.1 github.com/google/dotprompt/go v0.0.0-20250611200215-bb73406b05ca github.com/google/go-cmp v0.7.0 @@ -52,6 +53,8 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect + github.com/joho/godotenv v1.5.1 // indirect + github.com/ollama/ollama v0.6.5 // indirect go.opencensus.io v0.24.0 // indirect ) diff --git a/go/go.sum b/go/go.sum index ad59b6479c..a32916fd6f 100644 --- a/go/go.sum +++ b/go/go.sum @@ -69,6 +69,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f h1:C5bqEmzEPLsHm9Mv73lSE9e9bKV23aB1vxOsmZrkl3k= github.com/cncf/xds/go v0.0.0-20250326154945-ae57f3c0d45f/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cohesion-org/deepseek-go v1.3.2 h1:WTZ/2346KFYca+n+DL5p+Ar1RQxF2w/wGkU4jDvyXaQ= +github.com/cohesion-org/deepseek-go v1.3.2/go.mod h1:bOVyKj38r90UEYZFrmJOzJKPxuAh8sIzHOCnLOpiXeI= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -253,6 +255,8 @@ github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/ github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= +github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= +github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/karrick/godirwalk v1.8.0/go.mod h1:H5KPZjojv4lE+QYImBI8xVtrBRgYrIVsaRPx4tDPEn4= @@ -292,6 +296,8 @@ github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= +github.com/ollama/ollama v0.6.5 h1:vXKkVX57ql/1ZzMw4SVK866Qfd6pjwEcITVyEpF0QXQ= +github.com/ollama/ollama v0.6.5/go.mod h1:pGgtoNyc9DdM6oZI6yMfI6jTk2Eh4c36c2GpfQCH7PY= github.com/openai/openai-go v1.8.2 h1:UqSkJ1vCOPUpz9Ka5tS0324EJFEuOvMc+lA/EarJWP8= github.com/openai/openai-go v1.8.2/go.mod h1:g461MYGXEXBVdV5SaR/5tNzNbSfwTBBefwc+LlDCK0Y= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= diff --git a/go/plugins/deepseek/deepseek.go b/go/plugins/deepseek/deepseek.go new file mode 100644 index 0000000000..f65267257c --- /dev/null +++ b/go/plugins/deepseek/deepseek.go @@ -0,0 +1,431 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package deepseek + +import ( + "context" + "errors" + "fmt" + "os" + "sync" + + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/core" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/internal" + + deepseek "github.com/cohesion-org/deepseek-go" +) + +const ( + provider = "deepseek" + deepseekLabelPrefix = "DeepSeek" +) + +type DeepSeek struct { + APIKey string // DeepSeek API key, if not provided env var DEEPSEEK_API_KEY is looked up + + dsclient *deepseek.Client + mu sync.Mutex + initted bool +} + +// Name returns the name of the plugin +func (d *DeepSeek) Name() string { + return provider +} + +// Init initializes the DeepSeek plugin and all its known models +func (d *DeepSeek) Init(ctx context.Context, g *genkit.Genkit) (err error) { + if d == nil { + d = &DeepSeek{} + } + d.mu.Lock() + defer d.mu.Unlock() + if d.initted { + return errors.New("plugin already initialized") + } + defer func() { + if err != nil { + err = fmt.Errorf("DeepSeek.Init: %w", err) + } + }() + + apiKey := d.APIKey + if apiKey == "" { + apiKey = os.Getenv("DEEPSEEK_API_KEY") + if apiKey == "" { + return fmt.Errorf("DeepSeek requires setting DEEPSEEK_API_KEY environment variable") + } + } + + dsc := deepseek.NewClient(apiKey) + if dsc == nil { + return errors.New("unable to create deepseek client") + } + d.dsclient = dsc + d.initted = true + + return nil +} + +// Model returns the [ai.Model] with the given name +func (d *DeepSeek) Model(g *genkit.Genkit, name string) ai.Model { + return genkit.LookupModel(g, provider, name) +} + +// DefineModel defines an unknown model with the given name +func (d *DeepSeek) DefineModel(g *genkit.Genkit, name string, info ai.ModelInfo) (ai.Model, error) { + return defineModel(g, *d.dsclient, provider, name, info), nil +} + +// ListActions lists all resolvable actions by the plugin +func (d *DeepSeek) ListActions(ctx context.Context) []core.ActionDesc { + actions := []core.ActionDesc{} + models, err := listModels(ctx, *d.dsclient) + if err != nil { + return nil + } + + for _, name := range models { + metadata := map[string]any{ + "model": map[string]any{ + "supports": map[string]any{ + "media": false, // official deepseek models do not support media content + "multiturn": true, + "systemRole": true, + "tools": true, + "toolChoice": true, + "constrained": "no-tools", + }, + "versions": []string{}, + "stage": string(ai.ModelStageStable), + }, + } + metadata["label"] = fmt.Sprintf("%s - %s", deepseekLabelPrefix, name) + actions = append(actions, core.ActionDesc{ + Type: core.ActionTypeModel, + Name: fmt.Sprintf("%s/%s", provider, name), + Key: fmt.Sprintf("/%s/%s/%s", core.ActionTypeModel, provider, name), + Metadata: metadata, + }) + + } + + return actions +} + +// ResolveAction resolves all available actions from the plugin +func (d *DeepSeek) ResolveAction(g *genkit.Genkit, atype core.ActionType, name string) error { + switch atype { + case core.ActionTypeModel: + defineModel(g, *d.dsclient, provider, name, ai.ModelInfo{ + Label: fmt.Sprintf("%s - %s", deepseekLabelPrefix, name), + Stage: ai.ModelStageStable, + Versions: []string{}, + Supports: &internal.Multimodal, + }) + } + return nil +} + +// defineModel defines a model in the Genkit core +func defineModel(g *genkit.Genkit, client deepseek.Client, provider, name string, info ai.ModelInfo) ai.Model { + return genkit.DefineModel(g, provider, name, &info, func( + ctx context.Context, + input *ai.ModelRequest, + cb func(context.Context, *ai.ModelResponseChunk) error, + ) (*ai.ModelResponse, error) { + return generate(ctx, client, name, input, cb) + }) +} + +// listModels fetches the official deepseek models from the API +func listModels(ctx context.Context, client deepseek.Client) ([]string, error) { + apiModels, err := deepseek.ListAllModels(&client, ctx) + if err != nil { + return nil, err + } + models := []string{} + for _, m := range apiModels.Data { + if m.Object != "model" && m.OwnedBy != "deepseek" { + continue + } + } + + return models, nil +} + +func generate(ctx context.Context, client deepseek.Client, model string, input *ai.ModelRequest, cb func(context.Context, *ai.ModelResponseChunk) error, +) (*ai.ModelResponse, error) { + req, err := toDeepSeekRequest(input) + if err != nil { + return nil, err + } + + req.Model = model + + // constrained generation + hasOutput := input.Output != nil + isJsonFormat := hasOutput && input.Output.Format == "json" + isJsonContentType := hasOutput && input.Output.ContentType == "application/json" + req.JSONMode = isJsonFormat || isJsonContentType + if req.JSONMode { + req.ResponseFormat = &deepseek.ResponseFormat{ + Type: "json_object", + } + } + + // no stream request + if cb == nil { + resp, err := client.CreateChatCompletion(ctx, req) + if err != nil { + return nil, err + } + r := translateResponse(resp) + r.Request = input + + return r, nil + } + + return nil, nil +} + +func toDeepSeekRequest(input *ai.ModelRequest) (*deepseek.ChatCompletionRequest, error) { + req, err := configFromRequest(input) + if err != nil { + return nil, err + } + + // Genkit primitive fields must be used instead of deepseek config fields + // i.e.: system prompt, tools, cached content, response schema + if req.Model != "" { + return nil, errors.New("model must be set using Genkit feature: ai.WithModelName() or ai.WithModel()") + } + if req.Messages != nil { + return nil, errors.New("messages must be set using Genkit feature: ai.WithMessages()") + } + if req.ResponseFormat != nil { + return nil, errors.New("response format must be set using Genkit feature: ai.WithOutputType()") + } + if req.ToolChoice != nil { + return nil, errors.New("tool choice must be set using Genkit feature: ai.WithToolChoice()") + } + if req.Tools != nil { + return nil, errors.New("tools must be set using Genkit feature: ai.WithTools()") + } + + messages := []deepseek.ChatCompletionMessage{} + for _, m := range input.Messages { + if m.Role == ai.RoleSystem { + messages = append(messages, deepseek.ChatCompletionMessage{ + Role: toDeepSeekRole(m.Role), + Content: m.Text(), + }) + } else if m.Content[len(m.Content)-1].IsToolResponse() { + parts, err := toDeepSeekParts(m.Content, m.Role) + if err != nil { + return nil, err + } + messages = append(messages, parts...) + } else { + parts, err := toDeepSeekParts(m.Content, m.Role) + if err != nil { + return nil, err + } + messages = append(messages, parts...) + } + } + req.Messages = messages + + if len(input.Tools) > 0 { + tools, err := toDeepSeekTool(input.Tools) + if err != nil { + return nil, err + } + req.Tools = tools + + tc := toDeepSeekToolChoice(input.ToolChoice) + req.ToolChoice = tc + } + + return req, nil +} + +// toDeepSeekRole translates an [ai.Role] to a DeepSeek role +func toDeepSeekRole(role ai.Role) string { + var r string + switch role { + case ai.RoleSystem: + r = deepseek.ChatMessageRoleSystem + case ai.RoleModel: + r = deepseek.ChatMessageRoleAssistant + case ai.RoleTool: + r = deepseek.ChatMessageRoleTool + default: + r = deepseek.ChatMessageRoleUser + } + return r +} + +// toDeepSeekParts translates an array of [ai.Part] to [deepseek.ChatCompletionMessage] +func toDeepSeekParts(parts []*ai.Part, role ai.Role) ([]deepseek.ChatCompletionMessage, error) { + res := make([]deepseek.ChatCompletionMessage, 0, len(parts)) + for _, p := range parts { + part, err := toDeepSeekPart(p, role) + if err != nil { + return nil, err + } + res = append(res, part) + } + + return res, nil +} + +// toDeepSeekParts translates an [ai.Part] to [deepseek.ChatCompletionMessage] +func toDeepSeekPart(p *ai.Part, r ai.Role) (deepseek.ChatCompletionMessage, error) { + role := toDeepSeekRole(r) + switch { + case p.IsReasoning(): + content := "" + if p.Metadata != nil { + if c, ok := p.Metadata["content"].(string); ok { + content = c + } + } + return deepseek.ChatCompletionMessage{ + Role: toDeepSeekRole(r), + Content: content, + ReasoningContent: p.Text, + }, nil + case p.IsText(): + return deepseek.ChatCompletionMessage{ + Content: p.Text, + Role: role, + }, nil + default: + panic("unknown part type in the request") + } +} + +// toDeepSeekTool translates a slice of [ai.ToolDefinition] to a slice of [deepseek.Tool] +func toDeepSeekTool(inTools []*ai.ToolDefinition) ([]deepseek.Tool, error) { + tools := []deepseek.Tool{} + for _, t := range inTools { + tools = append(tools, deepseek.Tool{ + Type: "function", + Function: deepseek.Function{ + Name: t.Name, + Description: t.Description, + Parameters: &deepseek.FunctionParameters{ + Type: "object", + Properties: t.InputSchema, + }, + }, + }) + } + return tools, nil +} + +// toDeepSeekToolChoice translates an [ai.ToolChoice] to a DeepSeek Tool choice +func toDeepSeekToolChoice(choice ai.ToolChoice) string { + switch choice { + case ai.ToolChoiceAuto: + return "auto" + case ai.ToolChoiceNone: + return "none" + case ai.ToolChoiceRequired: + return "required" + } + return "none" +} + +// translateResponse translates a [deepseek.ChatCompletionResponse] to an [ai.ModelResponse] +func translateResponse(resp *deepseek.ChatCompletionResponse) *ai.ModelResponse { + var r *ai.ModelResponse + if len(resp.Choices) > 0 { + r = translateCandidate(resp.Choices[0]) + } else { + r = &ai.ModelResponse{} + } + + if r.Usage == nil { + r.Usage = &ai.GenerationUsage{} + } + + r.Usage.InputTokens = resp.Usage.PromptTokens + r.Usage.CachedContentTokens = resp.Usage.PromptCacheHitTokens + r.Usage.OutputTokens = resp.Usage.CompletionTokens + r.Usage.TotalTokens = resp.Usage.TotalTokens + + return r +} + +// translateCandidate translates a [deepseek.Choice] to an [ai.ModelResponse] +func translateCandidate(cand deepseek.Choice) *ai.ModelResponse { + m := &ai.ModelResponse{} + switch cand.FinishReason { + case "stop": + m.FinishReason = ai.FinishReasonStop + case "length": + m.FinishReason = ai.FinishReasonLength + case "content_filter": + m.FinishReason = ai.FinishReasonBlocked + case "tool_calls": + m.FinishReason = ai.FinishReasonOther + case "insufficient_system_resource": + m.FinishReason = ai.FinishReasonOther + default: + m.FinishReason = ai.FinishReasonOther + } + + msg := &ai.Message{} + msg.Role = ai.Role(cand.Message.Role) + + var p *ai.Part + // there's only one part in [deepseek.Choice], it could be either text or reasoning + if cand.Message.ReasoningContent != "" { + p = ai.NewReasoningPart(cand.Message.ReasoningContent, map[string]any{ + "content": cand.Message.Content, + }) + } else { + p = ai.NewTextPart(cand.Message.Content) + } + + msg.Content = append(msg.Content, p) + m.Message = msg + return m +} + +// configFromRequest ensures a valid DeepSeek configuration is used +func configFromRequest(input *ai.ModelRequest) (*deepseek.ChatCompletionRequest, error) { + var config deepseek.ChatCompletionRequest + + switch cfg := input.Config.(type) { + case deepseek.ChatCompletionRequest: + config = cfg + case *deepseek.ChatCompletionRequest: + config = *cfg + case map[string]any: + if err := internal.MapToStruct(cfg, config); err != nil { + return nil, err + } + case nil: + default: + return nil, fmt.Errorf("unexpected config type: %T", input.Config) + } + + return &config, nil +} diff --git a/go/plugins/googlegenai/gemini.go b/go/plugins/googlegenai/gemini.go index aa19387183..c690a38cdf 100644 --- a/go/plugins/googlegenai/gemini.go +++ b/go/plugins/googlegenai/gemini.go @@ -625,7 +625,9 @@ func translateCandidate(cand *genai.Candidate) *ai.ModelResponse { partFound := 0 if part.Thought { - p = ai.NewReasoningPart(part.Text, part.ThoughtSignature) + p = ai.NewReasoningPart(part.Text, map[string]any{ + "signature": part.ThoughtSignature, + }) partFound++ } if part.Text != "" && !part.Thought { @@ -712,7 +714,6 @@ func toGeminiParts(parts []*ai.Part) ([]*genai.Part, error) { // toGeminiPart converts a [ai.Part] to a [genai.Part]. func toGeminiPart(p *ai.Part) (*genai.Part, error) { - switch { case p.IsReasoning(): // TODO: go-genai does not support genai.NewPartFromThought() diff --git a/go/plugins/internal/models.go b/go/plugins/internal/models.go index 839d13001f..255f3f88ad 100644 --- a/go/plugins/internal/models.go +++ b/go/plugins/internal/models.go @@ -15,7 +15,11 @@ // Package internal contains code that is common to all models package internal -import "github.com/firebase/genkit/go/ai" +import ( + "encoding/json" + + "github.com/firebase/genkit/go/ai" +) var ( // BasicText describes model capabilities for text-only models. @@ -37,3 +41,12 @@ var ( Constrained: ai.ConstrainedSupportNoTools, } ) + +// MapToStruct unmarshals a map[String]any to the expected type +func MapToStruct(m map[string]any, v any) error { + jsonData, err := json.Marshal(m) + if err != nil { + return err + } + return json.Unmarshal(jsonData, v) +} diff --git a/go/samples/deepseek/main.go b/go/samples/deepseek/main.go new file mode 100644 index 0000000000..71c797ad25 --- /dev/null +++ b/go/samples/deepseek/main.go @@ -0,0 +1,51 @@ +// Copyright 2025 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "log" + + deepseek_ai "github.com/cohesion-org/deepseek-go" + "github.com/firebase/genkit/go/ai" + "github.com/firebase/genkit/go/genkit" + "github.com/firebase/genkit/go/plugins/deepseek" +) + +func main() { + ctx := context.Background() + + g, err := genkit.Init(ctx, genkit.WithPlugins(&deepseek.DeepSeek{})) + if err != nil { + log.Fatal(err) + } + + genkit.DefineFlow(g, "basic-prompt", func(ctx context.Context, input any) (string, error) { + resp, err := genkit.Generate(ctx, g, + ai.WithConfig(deepseek_ai.ChatCompletionRequest{ + Temperature: float32(1.0), + }), + ) + if err != nil { + return "", err + } + text := resp.Text() + return text, nil + }) + + <-ctx.Done() +}