Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/PATTERNS.md
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,32 @@ rag := orchestration.NewRAG(
result, _ := rag.Execute(ctx, userQuestion)
```

**Augmentation Strategies**:

The retrieved documents are combined with the original query before being
passed to the generator. Three built-in strategies are available — pick the
one that matches your generator agent's expected input format:

| Strategy | Output | When to use |
|----------|--------|-------------|
| `AugmentPrepend` (default) | `Context:\n<docs>\n\nQuery:\n<query>` | Plain-text generators / chat models |
| `AugmentJSON` | `{"context": "...", "query": "..."}` | Tool-calling agents that parse structured input |
| `AugmentTemplate` | User-supplied `text/template` with `{{.Context}}` and `{{.Query}}` | Custom prompt templates with system-specific scaffolding |

```go
// JSON augmentation for a structured generator
rag := orchestration.NewRAG("qa", runtime, "retriever", "generator",
orchestration.WithAugmentationStrategy(orchestration.AugmentJSON),
)

// Custom template
rag := orchestration.NewRAG("qa", runtime, "retriever", "generator",
orchestration.WithAugmentationTemplate(
"Use the following context to answer:\n{{.Context}}\n\nQ: {{.Query}}\nA:",
),
)
```

**Metrics Tracked**:
- Retrieval precision/recall
- Context usage (% of retrieved context used in answer)
Expand Down
125 changes: 106 additions & 19 deletions internal/orchestration/rag.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"maps"
"strings"
"text/template"
"time"

"github.com/aixgo-dev/aixgo/internal/agent"
Expand All @@ -15,6 +16,23 @@ import (
"go.opentelemetry.io/otel/trace"
)

// AugmentationStrategy controls how the query and retrieved documents are
// combined before being passed to the generator agent.
type AugmentationStrategy int

const (
// AugmentPrepend prepends retrieved context to the query (default).
// Renders as: "Context:\n<documents>\n\nQuery:\n<query>"
AugmentPrepend AugmentationStrategy = iota
// AugmentJSON encodes the augmented input as a JSON object so the
// generator can address each field independently:
// {"context": "<documents>", "query": "<query>"}
AugmentJSON
// AugmentTemplate renders a user-supplied text/template with the
// fields .Context and .Query. Use WithAugmentationTemplate to set it.
AugmentTemplate
)

// RAG implements Retrieval-Augmented Generation pattern.
// Retrieves relevant documents from a vector store, then generates grounded answers.
// Most common enterprise pattern for chatbots and Q&A systems.
Expand All @@ -26,15 +44,17 @@ import (
// - Context-aware generation
type RAG struct {
*BaseOrchestrator
retriever string // Agent that retrieves relevant documents
generator string // Agent that generates the answer
topK int // Number of documents to retrieve
rerank bool // Whether to rerank retrieved documents
reranker string // Optional reranker agent
conversationHist []ConversationTurn // For conversational RAG
historyAgent string // Agent for managing history
queryExpander string // For multi-query RAG
keywordRetriever string // For hybrid RAG
retriever string // Agent that retrieves relevant documents
generator string // Agent that generates the answer
topK int // Number of documents to retrieve
rerank bool // Whether to rerank retrieved documents
reranker string // Optional reranker agent
conversationHist []ConversationTurn // For conversational RAG
historyAgent string // Agent for managing history
queryExpander string // For multi-query RAG
keywordRetriever string // For hybrid RAG
augmentStrategy AugmentationStrategy
augmentTemplate *template.Template // Compiled template for AugmentTemplate
}

// ConversationTurn represents a single turn in conversation history
Expand Down Expand Up @@ -62,6 +82,35 @@ func WithReranker(reranker string) RAGOption {
}
}

// WithAugmentationStrategy selects how the query and retrieved documents are
// combined before being passed to the generator. AugmentPrepend (the default)
// produces a "Context:\n…\n\nQuery:\n…" string; AugmentJSON wraps the two
// fields as a JSON object; AugmentTemplate renders a text/template (set via
// WithAugmentationTemplate, which also flips the strategy to AugmentTemplate).
func WithAugmentationStrategy(s AugmentationStrategy) RAGOption {
return func(r *RAG) {
r.augmentStrategy = s
}
}

// WithAugmentationTemplate supplies a text/template used to format augmented
// input. The template may reference {{.Context}} and {{.Query}} fields.
// Setting a template implicitly switches the strategy to AugmentTemplate.
// Returns the option as-is on parse failure; the error surfaces at first
// Execute call so configuration mistakes aren't lost.
func WithAugmentationTemplate(tmpl string) RAGOption {
return func(r *RAG) {
// Parse failures will be surfaced at Execute time via augmentInput.
parsed, err := template.New("rag_augment").Parse(tmpl)
if err != nil {
r.augmentTemplate = nil
} else {
r.augmentTemplate = parsed
}
r.augmentStrategy = AugmentTemplate
}
}

// NewRAG creates a new RAG orchestrator
func NewRAG(name string, runtime agent.Runtime, retriever, generator string, opts ...RAGOption) *RAG {
r := &RAG{
Expand Down Expand Up @@ -153,7 +202,11 @@ func (r *RAG) Execute(ctx context.Context, input *agent.Message) (*agent.Message
}

// Step 3: Generate answer with retrieved context
augmentedInput := augmentInput(input, documents)
augmentedInput, err := r.augmentInput(input, documents)
if err != nil {
span.RecordError(err)
return nil, fmt.Errorf("augmentation failed: %w", err)
}

generateStart := time.Now()
result, err := r.runtime.Call(ctx, r.generator, augmentedInput)
Expand Down Expand Up @@ -395,21 +448,55 @@ func (r *RAG) hybridRetrieve(ctx context.Context, input *agent.Message) (*agent.
}, nil
}

// augmentInput combines the original query with retrieved documents
func augmentInput(query, documents *agent.Message) *agent.Message {
// augmentInput combines the original query with retrieved documents using the
// orchestrator's configured AugmentationStrategy. When no documents were
// retrieved (nil or empty payload), the original query is returned unchanged
// so that the generator can still respond — falling through to "ungrounded"
// generation is preferable to failing the whole pipeline on a sparse retriever.
func (r *RAG) augmentInput(query, documents *agent.Message) (*agent.Message, error) {
if query == nil || query.Message == nil {
return query
return query, nil
}

if documents == nil || documents.Message == nil || documents.Payload == "" {
// No documents retrieved, return original query
return query
return query, nil
}

var augmentedPayload string
switch r.augmentStrategy {
case AugmentJSON:
buf, err := json.Marshal(struct {
Context string `json:"context"`
Query string `json:"query"`
}{
Context: documents.Payload,
Query: query.Payload,
})
if err != nil {
return nil, fmt.Errorf("encode json augmentation: %w", err)
}
augmentedPayload = string(buf)
case AugmentTemplate:
if r.augmentTemplate == nil {
return nil, fmt.Errorf("augmentation strategy is template but no template was configured (use WithAugmentationTemplate)")
}
var buf strings.Builder
err := r.augmentTemplate.Execute(&buf, struct {
Context string
Query string
}{
Context: documents.Payload,
Query: query.Payload,
})
if err != nil {
return nil, fmt.Errorf("render augmentation template: %w", err)
}
augmentedPayload = buf.String()
default: // AugmentPrepend
augmentedPayload = fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload)
}

// Create augmented message with both query and retrieved context
// Format: "Context:\n{documents}\n\nQuery:\n{query}"
augmentedPayload := fmt.Sprintf("Context:\n%s\n\nQuery:\n%s", documents.Payload, query.Payload)

// Preserve metadata from both messages
metadata := make(map[string]any)
if query.Metadata != nil {
Expand All @@ -427,7 +514,7 @@ func augmentInput(query, documents *agent.Message) *agent.Message {
Timestamp: query.Timestamp,
Metadata: metadata,
},
}
}, nil
}

// RAG variants
Expand Down
68 changes: 67 additions & 1 deletion internal/orchestration/rag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package orchestration

import (
"context"
"encoding/json"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -156,8 +157,10 @@ func TestRAGPattern(t *testing.T) {
func TestAugmentInput(t *testing.T) {
tests := []struct {
name string
opts []RAGOption
query *agent.Message
documents *agent.Message
wantErr bool
wantType string
checkFunc func(t *testing.T, result *agent.Message)
}{
Expand Down Expand Up @@ -267,11 +270,74 @@ func TestAugmentInput(t *testing.T) {
}
},
},
{
name: "json strategy emits parseable JSON",
opts: []RAGOption{WithAugmentationStrategy(AugmentJSON)},
query: &agent.Message{
Message: &pb.Message{Payload: "What is AI?"},
},
documents: &agent.Message{
Message: &pb.Message{Payload: "AI is artificial intelligence"},
},
wantType: "rag_augmented",
checkFunc: func(t *testing.T, result *agent.Message) {
var got struct {
Context string `json:"context"`
Query string `json:"query"`
}
if err := json.Unmarshal([]byte(result.Payload), &got); err != nil {
t.Fatalf("payload is not valid JSON: %v\npayload: %s", err, result.Payload)
}
if got.Query != "What is AI?" {
t.Errorf("query field = %q, want %q", got.Query, "What is AI?")
}
if got.Context != "AI is artificial intelligence" {
t.Errorf("context field = %q, want %q", got.Context, "AI is artificial intelligence")
}
},
},
{
name: "template strategy renders user template",
opts: []RAGOption{
WithAugmentationTemplate("Use the following context to answer:\n{{.Context}}\n\nQuestion: {{.Query}}\nAnswer:"),
},
query: &agent.Message{
Message: &pb.Message{Payload: "What is AI?"},
},
documents: &agent.Message{
Message: &pb.Message{Payload: "AI is artificial intelligence"},
},
wantType: "rag_augmented",
checkFunc: func(t *testing.T, result *agent.Message) {
want := "Use the following context to answer:\nAI is artificial intelligence\n\nQuestion: What is AI?\nAnswer:"
if result.Payload != want {
t.Errorf("payload mismatch\n got: %q\n want: %q", result.Payload, want)
}
},
},
{
name: "template strategy without template configured returns error",
opts: []RAGOption{WithAugmentationStrategy(AugmentTemplate)},
query: &agent.Message{
Message: &pb.Message{Payload: "q"},
},
documents: &agent.Message{
Message: &pb.Message{Payload: "d"},
},
wantErr: true,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := augmentInput(tt.query, tt.documents)
r := NewRAG("test", NewMockRuntime(), "retriever", "generator", tt.opts...)
result, err := r.augmentInput(tt.query, tt.documents)
if (err != nil) != tt.wantErr {
t.Fatalf("augmentInput() err = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantErr {
return
}
if tt.wantType != "" && result != nil && result.Type != tt.wantType {
t.Errorf("Type = %s, want %s", result.Type, tt.wantType)
}
Expand Down
Loading