diff --git a/docs/PATTERNS.md b/docs/PATTERNS.md index 28e49d4..47bbbec 100644 --- a/docs/PATTERNS.md +++ b/docs/PATTERNS.md @@ -671,6 +671,32 @@ rag := orchestration.NewRAG( result, _ := rag.Execute(ctx, userQuestion) ``` +**Augmentation Strategies**: + +The retrieved documents are combined with the original query before being +passed to the generator. Three built-in strategies are available — pick the +one that matches your generator agent's expected input format: + +| Strategy | Output | When to use | +|----------|--------|-------------| +| `AugmentPrepend` (default) | `Context:\n<documents>\n\nQuery:\n<query>` | Plain-text generators / chat models | +| `AugmentJSON` | `{"context": "...", "query": "..."}` | Tool-calling agents that parse structured input | +| `AugmentTemplate` | User-supplied `text/template` with `{{.Context}}` and `{{.Query}}` | Custom prompt templates with system-specific scaffolding | + +```go +// JSON augmentation for a structured generator +rag := orchestration.NewRAG("qa", runtime, "retriever", "generator", + orchestration.WithAugmentationStrategy(orchestration.AugmentJSON), +) + +// Custom template +rag := orchestration.NewRAG("qa", runtime, "retriever", "generator", + orchestration.WithAugmentationTemplate( + "Use the following context to answer:\n{{.Context}}\n\nQ: {{.Query}}\nA:", + ), +) +``` + **Metrics Tracked**: - Retrieval precision/recall - Context usage (% of retrieved context used in answer) diff --git a/internal/orchestration/rag.go b/internal/orchestration/rag.go index 03653ab..3cb6820 100644 --- a/internal/orchestration/rag.go +++ b/internal/orchestration/rag.go @@ -6,6 +6,7 @@ import ( "fmt" "maps" "strings" + "text/template" "time" "github.com/aixgo-dev/aixgo/internal/agent" @@ -15,6 +16,23 @@ import ( "go.opentelemetry.io/otel/trace" ) +// AugmentationStrategy controls how the query and retrieved documents are +// combined before being passed to the generator agent. 
+type AugmentationStrategy int + +const ( + // AugmentPrepend prepends retrieved context to the query (default). + // Renders as: "Context:\n<documents>\n\nQuery:\n<query>" + AugmentPrepend AugmentationStrategy = iota + // AugmentJSON encodes the augmented input as a JSON object so the + // generator can address each field independently: + // {"context": "<documents>", "query": "<query>"} + AugmentJSON + // AugmentTemplate renders a user-supplied text/template with the + // fields .Context and .Query. Use WithAugmentationTemplate to set it. + AugmentTemplate +) + // RAG implements Retrieval-Augmented Generation pattern. // Retrieves relevant documents from a vector store, then generates grounded answers. // Most common enterprise pattern for chatbots and Q&A systems. @@ -26,15 +44,17 @@ import ( // - Context-aware generation type RAG struct { *BaseOrchestrator - retriever string // Agent that retrieves relevant documents - generator string // Agent that generates the answer - topK int // Number of documents to retrieve - rerank bool // Whether to rerank retrieved documents - reranker string // Optional reranker agent - conversationHist []ConversationTurn // For conversational RAG - historyAgent string // Agent for managing history - queryExpander string // For multi-query RAG - keywordRetriever string // For hybrid RAG + retriever string // Agent that retrieves relevant documents + generator string // Agent that generates the answer + topK int // Number of documents to retrieve + rerank bool // Whether to rerank retrieved documents + reranker string // Optional reranker agent + conversationHist []ConversationTurn // For conversational RAG + historyAgent string // Agent for managing history + queryExpander string // For multi-query RAG + keywordRetriever string // For hybrid RAG + augmentStrategy AugmentationStrategy // How query and documents are combined for the generator + augmentTemplate *template.Template // Compiled template for AugmentTemplate } // ConversationTurn represents a single turn in conversation history @@ -62,6 +82,35 @@ func 
WithReranker(reranker string) RAGOption { } } +// WithAugmentationStrategy selects how the query and retrieved documents are +// combined before being passed to the generator. AugmentPrepend (the default) +// produces a "Context:\n…\n\nQuery:\n…" string; AugmentJSON wraps the two +// fields as a JSON object; AugmentTemplate renders a text/template (set via +// WithAugmentationTemplate, which also flips the strategy to AugmentTemplate). +func WithAugmentationStrategy(s AugmentationStrategy) RAGOption { + return func(r *RAG) { + r.augmentStrategy = s + } +} + +// WithAugmentationTemplate supplies a text/template used to format augmented +// input. The template may reference {{.Context}} and {{.Query}} fields. +// Setting a template implicitly switches the strategy to AugmentTemplate. +// A template that fails to parse is replaced by one whose execution returns +// the parse error, so the mistake is reported when Execute first renders it. +func WithAugmentationTemplate(tmpl string) RAGOption { + return func(r *RAG) { + parsed, err := template.New("rag_augment").Parse(tmpl) + if err != nil { + // Carry the parse error to augmentInput through a template that fails with it. + fail := template.FuncMap{"fail": func() (string, error) { return "", err }} + parsed = template.Must(template.New("rag_augment").Funcs(fail).Parse("{{fail}}")) + } + r.augmentTemplate = parsed + r.augmentStrategy = AugmentTemplate + } +} + // NewRAG creates a new RAG orchestrator func NewRAG(name string, runtime agent.Runtime, retriever, generator string, opts ...RAGOption) *RAG { r := &RAG{ @@ -153,7 +202,11 @@ func (r *RAG) Execute(ctx context.Context, input *agent.Message) (*agent.Message } // Step 3: Generate answer with retrieved context - augmentedInput := augmentInput(input, documents) + augmentedInput, err := r.augmentInput(input, documents) + if err != nil { + span.RecordError(err) + return nil, fmt.Errorf("augmentation failed: %w", err) + } generateStart := time.Now() result, err := r.runtime.Call(ctx, r.generator, augmentedInput) @@ -395,21 +448,55 @@ func (r *RAG) hybridRetrieve(ctx context.Context, input *agent.Message) (*agent. 
}, nil } -// augmentInput combines the original query with retrieved documents -func augmentInput(query, documents *agent.Message) *agent.Message { +// augmentInput combines the original query with retrieved documents using the +// orchestrator's configured AugmentationStrategy. With a nil query, or when no +// documents were retrieved (nil or empty payload), the query is returned +// unchanged so the generator can still respond "ungrounded" instead of failing. +// It returns an error when encoding or rendering fails or no template is set. +func (r *RAG) augmentInput(query, documents *agent.Message) (*agent.Message, error) { if query == nil || query.Message == nil { - return query + return query, nil } if documents == nil || documents.Message == nil || documents.Payload == "" { // No documents retrieved, return original query - return query + return query, nil + } + + var augmentedPayload string + switch r.augmentStrategy { + case AugmentJSON: + buf, err := json.Marshal(struct { + Context string `json:"context"` + Query string `json:"query"` + }{ + Context: documents.Payload, + Query: query.Payload, + }) + if err != nil { + return nil, fmt.Errorf("encode json augmentation: %w", err) + } + augmentedPayload = string(buf) + case AugmentTemplate: + if r.augmentTemplate == nil { + return nil, fmt.Errorf("augmentation strategy is template but no template was configured (use WithAugmentationTemplate)") + } + var buf strings.Builder + err := r.augmentTemplate.Execute(&buf, struct { + Context string + Query string + }{ + Context: documents.Payload, + Query: query.Payload, + }) + if err != nil { + return nil, fmt.Errorf("render augmentation template: %w", err) + } + augmentedPayload = buf.String() + default: // AugmentPrepend + augmentedPayload = fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload) } - // Create augmented message with both query and retrieved context - // Format: "Context:\n{documents}\n\nQuery:\n{query}" - augmentedPayload := 
fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload) - // Preserve metadata from both messages metadata := make(map[string]any) if query.Metadata != nil { @@ -427,7 +514,7 @@ func augmentInput(query, documents *agent.Message) *agent.Message { Timestamp: query.Timestamp, Metadata: metadata, }, - } + }, nil } // RAG variants diff --git a/internal/orchestration/rag_test.go b/internal/orchestration/rag_test.go index 103d374..57bd43b 100644 --- a/internal/orchestration/rag_test.go +++ b/internal/orchestration/rag_test.go @@ -2,6 +2,7 @@ package orchestration import ( "context" + "encoding/json" "strings" "testing" "time" @@ -156,8 +157,10 @@ func TestRAGPattern(t *testing.T) { func TestAugmentInput(t *testing.T) { tests := []struct { name string + opts []RAGOption query *agent.Message documents *agent.Message + wantErr bool wantType string checkFunc func(t *testing.T, result *agent.Message) }{ @@ -267,11 +270,74 @@ func TestAugmentInput(t *testing.T) { } }, }, + { + name: "json strategy emits parseable JSON", + opts: []RAGOption{WithAugmentationStrategy(AugmentJSON)}, + query: &agent.Message{ + Message: &pb.Message{Payload: "What is AI?"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "AI is artificial intelligence"}, + }, + wantType: "rag_augmented", + checkFunc: func(t *testing.T, result *agent.Message) { + var got struct { + Context string `json:"context"` + Query string `json:"query"` + } + if err := json.Unmarshal([]byte(result.Payload), &got); err != nil { + t.Fatalf("payload is not valid JSON: %v\npayload: %s", err, result.Payload) + } + if got.Query != "What is AI?" 
{ + t.Errorf("query field = %q, want %q", got.Query, "What is AI?") + } + if got.Context != "AI is artificial intelligence" { + t.Errorf("context field = %q, want %q", got.Context, "AI is artificial intelligence") + } + }, + }, + { + name: "template strategy renders user template", + opts: []RAGOption{ + WithAugmentationTemplate("Use the following context to answer:\n{{.Context}}\n\nQuestion: {{.Query}}\nAnswer:"), + }, + query: &agent.Message{ + Message: &pb.Message{Payload: "What is AI?"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "AI is artificial intelligence"}, + }, + wantType: "rag_augmented", + checkFunc: func(t *testing.T, result *agent.Message) { + want := "Use the following context to answer:\nAI is artificial intelligence\n\nQuestion: What is AI?\nAnswer:" + if result.Payload != want { + t.Errorf("payload mismatch\n got: %q\n want: %q", result.Payload, want) + } + }, + }, + { + name: "template strategy without template configured returns error", + opts: []RAGOption{WithAugmentationStrategy(AugmentTemplate)}, + query: &agent.Message{ + Message: &pb.Message{Payload: "q"}, + }, + documents: &agent.Message{ + Message: &pb.Message{Payload: "d"}, + }, + wantErr: true, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := augmentInput(tt.query, tt.documents) + r := NewRAG("test", NewMockRuntime(), "retriever", "generator", tt.opts...) + result, err := r.augmentInput(tt.query, tt.documents) + if (err != nil) != tt.wantErr { + t.Fatalf("augmentInput() err = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } if tt.wantType != "" && result != nil && result.Type != tt.wantType { t.Errorf("Type = %s, want %s", result.Type, tt.wantType) }