@@ -3,21 +3,173 @@ import { createChatCompletion } from '../config';
import { type ChatItemType } from '@fastgpt/global/core/chat/type';
import { countGptMessagesTokens, countPromptTokens } from '../../../common/string/tiktoken/index';
import { chats2GPTMessages } from '@fastgpt/global/core/chat/adapt';
- import { getLLMModel } from '../model';
+ import { getLLMModel, getEmbeddingModel } from '../model';
+ import { getVectorsByText } from '../../ai/embedding';
import { llmCompletionsBodyFormat, formatLLMResponse } from '../utils';
import { addLog } from '../../../common/system/log';
import { filterGPTMessageByMaxContext } from '../../chat/utils';
import json5 from 'json5';

/*
-   query extension - 问题扩展
-   可以根据上下文,消除指代性问题以及扩展问题,利于检索。
+   Query Extension - Semantic Search Enhancement
+
+   This module eliminates referential ambiguity and expands queries based on context to improve retrieval.
+
+   Submodular optimization mode: generate multiple candidate queries, then use a submodular algorithm to select the optimal query combination.
*/

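+ // The helpers below implement the selection pipeline: embed the original query and every
+ // candidate query, then greedily pick a subset that maximizes relevance plus diversity.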
+ // Priority queue implementation for submodular optimization
+ class PriorityQueue<T> {
+   private heap: Array<{ item: T; priority: number }> = [];
+
+   enqueue(item: T, priority: number): void {
+     this.heap.push({ item, priority });
+     this.heap.sort((a, b) => b.priority - a.priority);
+   }
+
+   dequeue(): T | undefined {
+     return this.heap.shift()?.item;
+   }
+
+   isEmpty(): boolean {
+     return this.heap.length === 0;
+   }
+
+   size(): number {
+     return this.heap.length;
+   }
+ }
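+ // Note: enqueue re-sorts the backing array (O(n log n) per insert), which is fine for the
+ // ~10 candidate queries handled per request; a binary heap would suit larger candidate sets.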
+
+ // Calculate the cosine similarity of two vectors
+ function cosineSimilarity(a: number[], b: number[]): number {
+   if (a.length !== b.length) {
+     throw new Error('Vectors must have the same length');
+   }
+
+   let dotProduct = 0;
+   let normA = 0;
+   let normB = 0;
+
+   for (let i = 0; i < a.length; i++) {
+     dotProduct += a[i] * b[i];
+     normA += a[i] * a[i];
+     normB += b[i] * b[i];
+   }
+
+   if (normA === 0 || normB === 0) return 0;
+   return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
+ }
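+ // e.g. cosineSimilarity([1, 0], [0, 1]) === 0 (orthogonal), cosineSimilarity([1, 2], [2, 4]) === 1 (parallel)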
+
+ // Calculate the marginal gain of adding a candidate to the selected set:
+ // weighted relevance to the original query plus diversity from the queries already selected
+ function computeMarginalGain(
+   candidateEmbedding: number[],
+   selectedEmbeddings: number[][],
+   originalEmbedding: number[],
+   alpha: number = 0.3
+ ): number {
+   if (selectedEmbeddings.length === 0) {
+     return alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+   }
+
+   let maxSimilarity = 0;
+   for (const selectedEmbedding of selectedEmbeddings) {
+     const similarity = cosineSimilarity(candidateEmbedding, selectedEmbedding);
+     maxSimilarity = Math.max(maxSimilarity, similarity);
+   }
+
+   const relevance = alpha * cosineSimilarity(originalEmbedding, candidateEmbedding);
+   const diversity = 1 - maxSimilarity;
+
+   return relevance + diversity;
+ }
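+ // Example: with alpha = 0.3, a candidate with similarity 0.8 to the original query and a
+ // maximum similarity of 0.6 to the already-selected queries scores 0.3 * 0.8 + (1 - 0.6) = 0.64;
+ // a near-duplicate of a selected query keeps only its relevance term, so diverse candidates win.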
+
+ // Lazy greedy query selection algorithm
+ function lazyGreedyQuerySelection(
+   candidates: string[],
+   embeddings: number[][],
+   originalEmbedding: number[],
+   k: number,
+   alpha: number = 0.3
+ ): string[] {
+   const n = candidates.length;
+   const selected: string[] = [];
+   const selectedEmbeddings: number[][] = [];
+
+   // Initialize priority queue
+   const pq = new PriorityQueue<{ index: number; gain: number }>();
+
+   // Calculate initial marginal gain for all candidates
+   for (let i = 0; i < n; i++) {
+     const gain = computeMarginalGain(embeddings[i], selectedEmbeddings, originalEmbedding, alpha);
+     pq.enqueue({ index: i, gain }, gain);
+   }
+
+   // Greedy selection
+   for (let iteration = 0; iteration < k; iteration++) {
+     if (pq.isEmpty()) break;
+
+     let bestCandidate: { index: number; gain: number } | undefined;
+
+     // Find the candidate with maximum marginal gain. Lazy re-evaluation: gains can only
+     // shrink as the selected set grows, so a recomputed gain that still matches its cached
+     // value is the true maximum.
+     while (!pq.isEmpty()) {
+       const candidate = pq.dequeue()!;
+       const currentGain = computeMarginalGain(
+         embeddings[candidate.index],
+         selectedEmbeddings,
+         originalEmbedding,
+         alpha
+       );
+
+       if (currentGain >= candidate.gain) {
+         bestCandidate = { index: candidate.index, gain: currentGain };
+         break;
+       } else {
+         pq.enqueue(candidate, currentGain);
+       }
+     }
+
+     if (bestCandidate) {
+       selected.push(candidates[bestCandidate.index]);
+       selectedEmbeddings.push(embeddings[bestCandidate.index]);
+     }
+   }
+
+   return selected;
+ }
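+ // Usage (as in queryExtension below):
+ //   lazyGreedyQuerySelection(queries, candidateEmbeddings, originalEmbedding, Math.min(5, queries.length), 0.3)
+ // returns up to 5 candidates that balance relevance to the original query with mutual diversity.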
+
+ // Generate embeddings for the input texts
+ async function generateEmbeddings(texts: string[], model: string): Promise<number[][]> {
+   try {
+     const vectorModel = getEmbeddingModel(model);
+     const embeddings: number[][] = [];
+
+     for (const text of texts) {
+       // Embed each text with the configured vector model via getVectorsByText
+       const embedding = await getVectorsByText({
+         model: vectorModel,
+         input: text,
+         type: 'query'
+       });
+       embeddings.push(embedding.vectors[0]);
+     }
+
+     return embeddings;
+   } catch (error) {
+     addLog.warn('Failed to generate embeddings', { error, model });
+     throw error;
+   }
+ }
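+ // Note: each text is embedded with its own getVectorsByText call; queryExtension below
+ // embeds the original query together with all candidates in one generateEmbeddings pass.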
+

const title = global.feConfigs?.systemTitle || 'FastAI';
const defaultPrompt = `## 你的任务
- 你作为一个向量检索助手,你的任务是结合历史记录,从不同角度,为“原问题”生成个不同版本的“检索词”,从而提高向量检索的语义丰富度,提高向量检索的精度。
- 生成的问题要求指向对象清晰明确,并与“原问题语言相同”。
+ 你作为一个向量检索助手,你的任务是结合历史记录,为"原问题"生成{{count}}个不同版本的"检索词"。这些检索词应该从不同角度探索主题,以提高向量检索的语义丰富度和精度。
+
+ ## 要求
+ 1. 每个检索词必须与原问题相关
+ 2. 检索词应该探索不同方面(例如:原因、影响、解决方案、示例、对比等)
+ 3. 避免检索词之间的冗余
+ 4. 保持检索词简洁且可搜索
+ 5. 生成的问题要求指向对象清晰明确,并与"原问题语言相同"

## 参考示例
@@ -26,15 +178,15 @@ const defaultPrompt = `## 你的任务
null
"""
原问题: 介绍下剧情。
- 检索词: ["介绍下故事的背景。","故事的主题是什么?","介绍下故事的主要人物。"]
+ 检索词: ["介绍下故事的背景。","故事的主题是什么?","介绍下故事的主要人物。","故事的转折点在哪里?","故事的结局如何?"]
----------------
历史记录:
"""
user: 对话背景。
assistant: 当前对话是关于 Nginx 的介绍和使用等。
"""
原问题: 怎么下载
- 检索词: ["Nginx 如何下载?","下载 Nginx 需要什么条件?","有哪些渠道可以下载 Nginx?"]
+ 检索词: ["Nginx 如何下载?","下载 Nginx 需要什么条件?","有哪些渠道可以下载 Nginx?","Nginx 各版本的下载方式有什么区别?","如何选择合适的 Nginx 版本下载?"]
----------------
历史记录:
"""
@@ -44,23 +196,23 @@ user: 报错 "no connection"
assistant: 报错"no connection"可能是因为……
"""
原问题: 怎么解决
- 检索词: ["Nginx报错"no connection"如何解决?","造成'no connection'报错的原因。","Nginx提示'no connection',要怎么办?"]
+ 检索词: ["Nginx报错'no connection'如何解决?","造成'no connection'报错的原因。","Nginx提示'no connection',要怎么办?","'no connection'错误的常见解决步骤。","如何预防 Nginx 'no connection' 错误?"]
----------------
历史记录:
"""
user: How long is the maternity leave?
assistant: The number of days of maternity leave depends on the city in which the employee is located. Please provide your city so that I can answer your questions.
"""
原问题: ShenYang
- 检索词: ["How many days is maternity leave in Shenyang?","Shenyang's maternity leave policy.","The standard of maternity leave in Shenyang."]
+ 检索词: ["How many days is maternity leave in Shenyang?","Shenyang's maternity leave policy.","The standard of maternity leave in Shenyang.","What benefits are included in Shenyang's maternity leave?","How to apply for maternity leave in Shenyang?"]
----------------
历史记录:
"""
user: 作者是谁?
assistant: ${title} 的作者是 labring。
"""
原问题: Tell me about him
- 检索词: ["Introduce labring, the author of ${title}.","Background information on author labring.","Why does labring do ${title}?"]
+ 检索词: ["Introduce labring, the author of ${title}.","Background information on author labring.","Why does labring do ${title}?","What other projects has labring worked on?","How did labring start ${title}?"]
----------------
历史记录:
"""
@@ -76,7 +228,7 @@ user: ${title} 如何收费?
assistant: ${title} 收费可以参考……
"""
原问题: 你知道 laf 么?
- 检索词: ["laf 的官网地址是多少?","laf 的使用教程。","laf 有什么特点和优势。"]
+ 检索词: ["laf 的官网地址是多少?","laf 的使用教程。","laf 有什么特点和优势。","laf 的主要功能是什么?","laf 与其他类似产品的对比。"]
----------------
历史记录:
"""
@@ -102,6 +254,7 @@ assistant: Laf 是一个云函数开发平台。
1. 输出格式为 JSON 数组,数组中每个元素为字符串。无需对输出进行任何解释。
2. 输出语言与原问题相同。原问题为中文则输出中文;原问题为英文则输出英文。
+ 3. 确保生成恰好 {{count}} 个检索词。

## 开始任务
@@ -116,12 +269,14 @@ export const queryExtension = async ({
  chatBg,
  query,
  histories = [],
-   model
+   model,
+   generateCount = 10 // Number of candidate queries to generate; defaults to 10
}: {
  chatBg?: string;
  query: string;
  histories: ChatItemType[];
  model: string;
+   generateCount?: number;
}): Promise<{
  rawQuery: string;
  extensionQueries: string[];
@@ -162,7 +317,8 @@ assistant: ${chatBg}
      role: 'user',
      content: replaceVariable(defaultPrompt, {
        query: `${query}`,
-         histories: concatFewShot || 'null'
+         histories: concatFewShot || 'null',
+         count: generateCount.toString()
      })
    }
  ] as any;
@@ -216,15 +372,40 @@ assistant: ${chatBg}
  try {
    const queries = json5.parse(jsonStr) as string[];

+     if (!Array.isArray(queries) || queries.length === 0) {
+       return {
+         rawQuery: query,
+         extensionQueries: [],
+         model,
+         inputTokens,
+         outputTokens
+       };
+     }
+
+     // Generate embeddings for the original query and the candidate queries
+     const allQueries = [query, ...queries];
+     const embeddings = await generateEmbeddings(allQueries, model);
+     const originalEmbedding = embeddings[0];
+     const candidateEmbeddings = embeddings.slice(1);
+     // Select the optimal queries with the lazy greedy algorithm
+     const selectedQueries = lazyGreedyQuerySelection(
+       queries,
+       candidateEmbeddings,
+       originalEmbedding,
+       Math.min(5, queries.length), // select at most 5 queries
+       0.3 // alpha parameter balancing relevance and diversity
+     );
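+     // selectedQueries now holds up to 5 candidates chosen for relevance and mutual
+     // diversity, replacing the previous unconditional queries.slice(0, 5).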
+
    return {
      rawQuery: query,
-       extensionQueries: (Array.isArray(queries) ? queries : []).slice(0, 5),
+       extensionQueries: selectedQueries,
      model,
      inputTokens,
      outputTokens
    };
  } catch (error) {
-     addLog.warn('Query extension failed, not a valid JSON', {
+     addLog.warn('Query extension failed', {
+       error,
      answer
    });
    return {