Skip to content

Commit abb22cf

Browse files
committed
fix gpu off error and add GetEmbedding api
1 parent 706caa6 commit abb22cf

File tree

3 files changed

+78
-5
lines changed

3 files changed

+78
-5
lines changed

embed_windows.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,23 @@ func getDl(gpu bool) []byte {
3737
log.Println(err)
3838
}
3939
if supportGpuTable[gpuInfo] != nil {
40-
return supportGpuTable[gpu]
41-
} else {
42-
log.Println("GPU not support, use CPU instead.")
40+
return supportGpuTable[gpuInfo]
4341
}
42+
log.Println("GPU not support, use CPU instead.")
4443
}
4544
if cpu.X86.HasAVX512 {
45+
log.Println("Use CPU AVX512 instead.")
4646
return libRwkvAvx512
4747
}
4848
if cpu.X86.HasAVX2 {
49+
log.Println("Use CPU AVX2 instead.")
4950
return libRwkvAvx2
5051
}
5152
if cpu.X86.HasAVX {
53+
log.Println("Use CPU AVX instead.")
5254
return libRwkvAvx
5355
}
54-
//return libRwkvHipLAS
56+
5557
panic("Automatic loading of dynamic library failed, please use `NewRwkvModel` method load manually. ")
5658
return nil
5759
}

rwkv.go

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ func NewRwkvModel(dylibPath string, options RwkvOptions) (*RwkvModel, error) {
7777
}
7878

7979
if options.GpuEnable {
80-
log.Printf("You are about to offload your model to the GPU. " +
80+
log.Printf("If you want to try offload your model to the GPU. " +
8181
"Please confirm the size of your GPU memory to prevent memory overflow." +
8282
"If the model is larger than GPU memory, please specify the layers to offload.")
8383
}
@@ -181,6 +181,29 @@ func (s *RwkvState) Predict(input string) (string, error) {
181181
return s.generateResponse(nil)
182182
}
183183

184+
// GetEmbedding give the model embedding.
185+
// the embedding in rwkv is hidden state the len is n_emb*5*n_layer=46080.
186+
// So if distillation is true, we split len to n_emb = 768
187+
func (s *RwkvState) GetEmbedding(input string, distill bool) ([]float32, error) {
188+
encode, err := s.rwkvModel.tokenizer.Encode(input)
189+
190+
for _, token := range encode {
191+
err = s.rwkvModel.cRwkv.RwkvEval(s.rwkvModel.ctx, uint32(token), s.state, s.state, s.logits)
192+
if err != nil {
193+
return nil, err
194+
}
195+
}
196+
// we should keep state clean
197+
nState := s.rwkvModel.cRwkv.RwkvGetStateLength(s.rwkvModel.ctx)
198+
if distill {
199+
nState = s.rwkvModel.cRwkv.RwkvGetNEmbedding(s.rwkvModel.ctx)
200+
}
201+
emb := make([]float32, nState)
202+
copy(emb, s.state)
203+
s.state = make([]float32, nState)
204+
return emb, nil
205+
}
206+
184207
func (s *RwkvState) PredictStream(input string, output chan string) {
185208
go func() {
186209
err := s.handelInput(input)

rwkv_test.go

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,3 +277,51 @@ func TestChat(t *testing.T) {
277277
})
278278

279279
}
280+
281+
func TestRwkvState_GetEmbedding(t *testing.T) {
282+
rwkv, err := NewRwkvAutoModel(RwkvOptions{
283+
MaxTokens: 500,
284+
TokenizerType: Normal, //or World
285+
PrintError: true,
286+
CpuThreads: 2,
287+
GpuEnable: true,
288+
})
289+
defer rwkv.Close()
290+
291+
err = rwkv.LoadFromFile("./models/RWKV-4b-Pile-171M-20230202-7922-f16.bin")
292+
if err != nil {
293+
t.Error(err)
294+
return
295+
}
296+
ctx, err := rwkv.InitState()
297+
if err != nil {
298+
t.Error(err)
299+
return
300+
}
301+
nb := ctx.rwkvModel.cRwkv.RwkvGetNEmbedding(ctx.rwkvModel.ctx)
302+
t.Log(nb)
303+
nl := ctx.rwkvModel.cRwkv.RwkvGetNLayer(ctx.rwkvModel.ctx)
304+
305+
t.Run("hidden state", func(t *testing.T) {
306+
embedding, err := ctx.GetEmbedding("hello word", false)
307+
if err != nil {
308+
t.Error(err)
309+
return
310+
}
311+
t.Log(embedding)
312+
t.Log(len(embedding))
313+
t.Log(nl)
314+
assert(t, len(embedding) == int(nb*5*nl))
315+
})
316+
317+
t.Run("distill hidden state ", func(t *testing.T) {
318+
embedding, err := ctx.GetEmbedding("hello word", true)
319+
if err != nil {
320+
t.Error(err)
321+
return
322+
}
323+
t.Log(embedding)
324+
assert(t, len(embedding) == int(nb))
325+
})
326+
327+
}

0 commit comments

Comments
 (0)