Skip to content

Commit 8a9bbf3

Browse files
dcc123456yingfeng
andauthored
Feat: add memory function by go (#13754)
### What problem does this PR solve? Feat: Add Memory function by go ### Type of change - [x] New Feature (non-breaking change which adds functionality) --------- Co-authored-by: Yingfeng <yingfeng.zhang@gmail.com>
1 parent 406339a commit 8a9bbf3

File tree

11 files changed

+2350
-62
lines changed

11 files changed

+2350
-62
lines changed

cmd/server_main.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,7 @@ func startServer(config *server.Config) {
175175
connectorService := service.NewConnectorService()
176176
searchService := service.NewSearchService()
177177
fileService := service.NewFileService()
178+
memoryService := service.NewMemoryService()
178179

179180
// Initialize handler layer
180181
authHandler := handler.NewAuthHandler()
@@ -191,9 +192,10 @@ func startServer(config *server.Config) {
191192
connectorHandler := handler.NewConnectorHandler(connectorService, userService)
192193
searchHandler := handler.NewSearchHandler(searchService, userService)
193194
fileHandler := handler.NewFileHandler(fileService, userService)
195+
memoryHandler := handler.NewMemoryHandler(memoryService)
194196

195197
// Initialize router
196-
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler)
198+
r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler)
197199

198200
// Create Gin engine
199201
ginEngine := gin.New()

internal/dao/memory.go

Lines changed: 370 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,370 @@
1+
//
2+
// Copyright 2026 The InfiniFlow Authors. All Rights Reserved.
3+
//
4+
// Licensed under the Apache License, Version 2.0 (the "License");
5+
// you may not use this file except in compliance with the License.
6+
// You may obtain a copy of the License at
7+
//
8+
// http://www.apache.org/licenses/LICENSE-2.0
9+
//
10+
// Unless required by applicable law or agreed to in writing, software
11+
// distributed under the License is distributed on an "AS IS" BASIS,
12+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
// See the License for the specific language governing permissions and
14+
// limitations under the License.
15+
//
16+
17+
// Package dao implements the data access layer
18+
// This file implements Memory-related database operations
19+
// Consistent with Python memory_service.py
20+
package dao
21+
22+
import (
23+
"fmt"
24+
"strings"
25+
26+
"ragflow/internal/model"
27+
)
28+
29+
// Memory type bit flag constants, consistent with Python MemoryType enum
30+
const (
31+
MemoryTypeRaw = 0b0001 // Raw memory (binary: 0001)
32+
MemoryTypeSemantic = 0b0010 // Semantic memory (binary: 0010)
33+
MemoryTypeEpisodic = 0b0100 // Episodic memory (binary: 0100)
34+
MemoryTypeProcedural = 0b1000 // Procedural memory (binary: 1000)
35+
)
36+
37+
// MemoryTypeMap maps memory type names to bit flags
38+
// Exported for use by service package
39+
var MemoryTypeMap = map[string]int{
40+
"raw": MemoryTypeRaw,
41+
"semantic": MemoryTypeSemantic,
42+
"episodic": MemoryTypeEpisodic,
43+
"procedural": MemoryTypeProcedural,
44+
}
45+
46+
// CalculateMemoryType converts memory type names array to bit flags integer
47+
//
48+
// Parameters:
49+
// - memoryTypeNames: Memory type names array
50+
//
51+
// Returns:
52+
// - int64: Bit flags integer
53+
//
54+
// Example:
55+
//
56+
// CalculateMemoryType([]string{"raw", "semantic"}) returns 3 (0b0011)
57+
func CalculateMemoryType(memoryTypeNames []string) int64 {
58+
memoryType := 0
59+
for _, name := range memoryTypeNames {
60+
lowerName := strings.ToLower(name)
61+
if mt, ok := MemoryTypeMap[lowerName]; ok {
62+
memoryType |= mt
63+
}
64+
}
65+
return int64(memoryType)
66+
}
67+
68+
// GetMemoryTypeHuman converts memory type bit flags to human-readable names
69+
//
70+
// Parameters:
71+
// - memoryType: Bit flags integer representing memory types
72+
//
73+
// Returns:
74+
// - []string: Array of human-readable memory type names
75+
//
76+
// Example:
77+
//
78+
// GetMemoryTypeHuman(3) returns ["raw", "semantic"]
79+
func GetMemoryTypeHuman(memoryType int64) []string {
80+
var result []string
81+
if memoryType&int64(MemoryTypeRaw) != 0 {
82+
result = append(result, "raw")
83+
}
84+
if memoryType&int64(MemoryTypeSemantic) != 0 {
85+
result = append(result, "semantic")
86+
}
87+
if memoryType&int64(MemoryTypeEpisodic) != 0 {
88+
result = append(result, "episodic")
89+
}
90+
if memoryType&int64(MemoryTypeProcedural) != 0 {
91+
result = append(result, "procedural")
92+
}
93+
return result
94+
}
95+
96+
// MemoryDAO handles all Memory-related database operations
97+
type MemoryDAO struct{}
98+
99+
// NewMemoryDAO creates a new MemoryDAO instance
100+
//
101+
// Returns:
102+
// - *MemoryDAO: Initialized DAO instance
103+
func NewMemoryDAO() *MemoryDAO {
104+
return &MemoryDAO{}
105+
}
106+
107+
// Create inserts a new memory record into the database
108+
//
109+
// Parameters:
110+
// - memory: Memory model pointer
111+
//
112+
// Returns:
113+
// - error: Database operation error
114+
func (dao *MemoryDAO) Create(memory *model.Memory) error {
115+
return DB.Create(memory).Error
116+
}
117+
118+
// GetByID retrieves a memory record by ID from database
119+
//
120+
// Parameters:
121+
// - id: Memory ID
122+
//
123+
// Returns:
124+
// - *model.Memory: Memory model pointer
125+
// - error: Database operation error
126+
func (dao *MemoryDAO) GetByID(id string) (*model.Memory, error) {
127+
var memory model.Memory
128+
err := DB.Where("id = ?", id).First(&memory).Error
129+
if err != nil {
130+
return nil, err
131+
}
132+
return &memory, nil
133+
}
134+
135+
// GetByTenantID retrieves all memories for a tenant
136+
//
137+
// Parameters:
138+
// - tenantID: Tenant ID
139+
//
140+
// Returns:
141+
// - []*model.Memory: Memory model pointer array
142+
// - error: Database operation error
143+
func (dao *MemoryDAO) GetByTenantID(tenantID string) ([]*model.Memory, error) {
144+
var memories []*model.Memory
145+
err := DB.Where("tenant_id = ?", tenantID).Find(&memories).Error
146+
return memories, err
147+
}
148+
149+
// GetByNameAndTenant checks if memory exists by name and tenant ID
150+
// Used for duplicate name deduplication
151+
//
152+
// Parameters:
153+
// - name: Memory name
154+
// - tenantID: Tenant ID
155+
//
156+
// Returns:
157+
// - []*model.Memory: Matching memory list (for existence check)
158+
// - error: Database operation error
159+
func (dao *MemoryDAO) GetByNameAndTenant(name string, tenantID string) ([]*model.Memory, error) {
160+
var memories []*model.Memory
161+
err := DB.Where("name = ? AND tenant_id = ?", name, tenantID).Find(&memories).Error
162+
return memories, err
163+
}
164+
165+
// GetByIDs retrieves memories by multiple IDs
166+
//
167+
// Parameters:
168+
// - ids: Memory ID list
169+
//
170+
// Returns:
171+
// - []*model.Memory: Memory model pointer array
172+
// - error: Database operation error
173+
func (dao *MemoryDAO) GetByIDs(ids []string) ([]*model.Memory, error) {
174+
var memories []*model.Memory
175+
err := DB.Where("id IN ?", ids).Find(&memories).Error
176+
return memories, err
177+
}
178+
179+
// UpdateByID updates a memory by ID
180+
// Supports partial updates - only updates passed fields
181+
// Automatically handles field type conversions
182+
//
183+
// Parameters:
184+
// - id: Memory ID
185+
// - updates: Fields to update map
186+
//
187+
// Returns:
188+
// - error: Database operation error
189+
//
190+
// Field type handling:
191+
// - memory_type: []string converts to bit flags integer
192+
// - temperature: string converts to float64
193+
// - name: Uses string value directly
194+
// - permissions, forgetting_policy: Uses string value directly
195+
//
196+
// Example:
197+
//
198+
// updates := map[string]interface{}{"name": "NewName", "memory_type": []string{"semantic"}}
199+
// err := dao.UpdateByID("memory123", updates)
200+
func (dao *MemoryDAO) UpdateByID(id string, updates map[string]interface{}) error {
201+
if updates == nil || len(updates) == 0 {
202+
return nil
203+
}
204+
205+
for key, value := range updates {
206+
switch key {
207+
case "memory_type":
208+
if types, ok := value.([]string); ok {
209+
updates[key] = CalculateMemoryType(types)
210+
}
211+
case "temperature":
212+
if tempStr, ok := value.(string); ok {
213+
var temp float64
214+
fmt.Sscanf(tempStr, "%f", &temp)
215+
updates[key] = temp
216+
}
217+
}
218+
}
219+
220+
return DB.Model(&model.Memory{}).Where("id = ?", id).Updates(updates).Error
221+
}
222+
223+
// DeleteByID deletes a memory by ID
224+
//
225+
// Parameters:
226+
// - id: Memory ID
227+
//
228+
// Returns:
229+
// - error: Database operation error
230+
//
231+
// Example:
232+
//
233+
// err := dao.DeleteByID("memory123")
234+
func (dao *MemoryDAO) DeleteByID(id string) error {
235+
return DB.Where("id = ?", id).Delete(&model.Memory{}).Error
236+
}
237+
238+
// GetWithOwnerNameByID retrieves a memory with owner name by ID
239+
// Joins with User table to get owner's nickname
240+
//
241+
// Parameters:
242+
// - id: Memory ID
243+
//
244+
// Returns:
245+
// - *model.MemoryListItem: Memory detail with owner name populated
246+
// - error: Database operation error
247+
//
248+
// Example:
249+
//
250+
// memory, err := dao.GetWithOwnerNameByID("memory123")
251+
func (dao *MemoryDAO) GetWithOwnerNameByID(id string) (*model.MemoryListItem, error) {
252+
querySQL := `
253+
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
254+
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
255+
m.permissions, m.description, m.memory_size, m.forgetting_policy,
256+
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
257+
m.update_time, m.update_date,
258+
u.nickname as owner_name
259+
FROM memory m
260+
LEFT JOIN user u ON m.tenant_id = u.id
261+
WHERE m.id = ?
262+
`
263+
264+
var rawResult struct {
265+
model.Memory
266+
OwnerName *string `gorm:"column:owner_name"`
267+
}
268+
269+
if err := DB.Raw(querySQL, id).Scan(&rawResult).Error; err != nil {
270+
return nil, err
271+
}
272+
273+
return &model.MemoryListItem{
274+
Memory: rawResult.Memory,
275+
OwnerName: rawResult.OwnerName,
276+
}, nil
277+
}
278+
279+
// GetByFilter retrieves memories with optional filters
280+
// Supports filtering by tenant_id, memory_type, storage_type, and keywords
281+
// Returns paginated results with owner_name from user table JOIN
282+
//
283+
// Parameters:
284+
// - tenantIDs: Array of tenant IDs to filter by (empty means all tenants)
285+
// - memoryTypes: Array of memory type names to filter by (empty means all types)
286+
// - storageType: Storage type to filter by (empty means all types)
287+
// - keywords: Keywords to search in memory names (empty means no keyword filter)
288+
// - page: Page number (1-based)
289+
// - pageSize: Number of items per page
290+
//
291+
// Returns:
292+
// - []*model.MemoryListItem: Memory list items with owner name populated
293+
// - int64: Total count of matching memories
294+
// - error: Database operation error
295+
//
296+
// Example:
297+
//
298+
// memories, total, err := dao.GetByFilter([]string{"tenant1"}, []string{"semantic"}, "table", "test", 1, 10)
299+
func (dao *MemoryDAO) GetByFilter(tenantIDs []string, memoryTypes []string, storageType string, keywords string, page int, pageSize int) ([]*model.MemoryListItem, int64, error) {
300+
var conditions []string
301+
var args []interface{}
302+
303+
if len(tenantIDs) > 0 {
304+
conditions = append(conditions, "m.tenant_id IN ?")
305+
args = append(args, tenantIDs)
306+
}
307+
308+
if len(memoryTypes) > 0 {
309+
memoryTypeInt := CalculateMemoryType(memoryTypes)
310+
conditions = append(conditions, "m.memory_type & ? > 0")
311+
args = append(args, memoryTypeInt)
312+
}
313+
314+
if storageType != "" {
315+
conditions = append(conditions, "m.storage_type = ?")
316+
args = append(args, storageType)
317+
}
318+
319+
if keywords != "" {
320+
conditions = append(conditions, "m.name LIKE ?")
321+
args = append(args, "%"+keywords+"%")
322+
}
323+
324+
whereClause := ""
325+
if len(conditions) > 0 {
326+
whereClause = "WHERE " + strings.Join(conditions, " AND ")
327+
}
328+
329+
countSQL := fmt.Sprintf("SELECT COUNT(*) FROM memory m %s", whereClause)
330+
var total int64
331+
if err := DB.Raw(countSQL, args...).Scan(&total).Error; err != nil {
332+
return nil, 0, err
333+
}
334+
335+
offset := (page - 1) * pageSize
336+
querySQL := fmt.Sprintf(`
337+
SELECT m.id, m.name, m.avatar, m.tenant_id, m.memory_type,
338+
m.storage_type, m.embd_id, m.tenant_embd_id, m.llm_id, m.tenant_llm_id,
339+
m.permissions, m.description, m.memory_size, m.forgetting_policy,
340+
m.temperature, m.system_prompt, m.user_prompt, m.create_time, m.create_date,
341+
m.update_time, m.update_date,
342+
u.nickname as owner_name
343+
FROM memory m
344+
LEFT JOIN user u ON m.tenant_id = u.id
345+
%s
346+
ORDER BY m.update_time DESC
347+
LIMIT ? OFFSET ?
348+
`, whereClause)
349+
350+
queryArgs := append(args, pageSize, offset)
351+
352+
var rawResults []struct {
353+
model.Memory
354+
OwnerName *string `gorm:"column:owner_name"`
355+
}
356+
357+
if err := DB.Raw(querySQL, queryArgs...).Scan(&rawResults).Error; err != nil {
358+
return nil, 0, err
359+
}
360+
361+
memories := make([]*model.MemoryListItem, len(rawResults))
362+
for i, r := range rawResults {
363+
memories[i] = &model.MemoryListItem{
364+
Memory: r.Memory,
365+
OwnerName: r.OwnerName,
366+
}
367+
}
368+
369+
return memories, total, nil
370+
}

0 commit comments

Comments
 (0)