Skip to content

Commit 6de9968

Browse files
authored
Merge pull request #17 from codefuse-ai/modelcache_localDB_dev
Modelcache local db dev
2 parents 3369c73 + edecddb commit 6de9968

File tree

4 files changed

+75
-12
lines changed

4 files changed

+75
-12
lines changed

modelcache/embedding/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
huggingface = LazyImport("huggingface", globals(), "modelcache.embedding.huggingface")
44
data2vec = LazyImport("data2vec", globals(), "modelcache.embedding.data2vec")
55
llmEmb = LazyImport("llmEmb", globals(), "modelcache.embedding.llmEmb")
6-
fasttext = LazyImport("fasttext", globals(), "gptcache.embedding.fasttext")
6+
fasttext = LazyImport("fasttext", globals(), "modelcache.embedding.fasttext")
7+
paddlenlp = LazyImport("paddlenlp", globals(), "modelcache.embedding.paddlenlp")
78

89

910
def Huggingface(model="sentence-transformers/all-mpnet-base-v2"):
@@ -20,3 +21,7 @@ def LlmEmb2vecAudio():
2021

2122
def FastText(model="en", dim=None):
2223
return fasttext.FastText(model, dim)
24+
25+
26+
def PaddleNLP(model="ernie-3.0-medium-zh"):
27+
return paddlenlp.PaddleNLP(model)

modelcache/embedding/fasttext.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,4 @@
11
# -*- coding: utf-8 -*-
2-
"""
3-
Alipay.com Inc.
4-
Copyright (c) 2004-2023 All Rights Reserved.
5-
------------------------------------------------------
6-
File Name : fasttext.py
7-
Author : fuhui.phe
8-
Create Time : 2023/12/3 15:40
9-
Description : description what the main function of this file
10-
Change Activity:
11-
version0 : 2023/12/3 15:40 by fuhui.phe init
12-
"""
132
import numpy as np
143
import os
154
from modelcache.utils import import_fasttext

modelcache/embedding/paddlenlp.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# -*- coding: utf-8 -*-
2+
import numpy as np
3+
4+
from modelcache.embedding.base import BaseEmbedding
5+
from modelcache.utils import import_paddlenlp, import_paddle
6+
7+
import_paddle()
8+
import_paddlenlp()
9+
10+
11+
import paddle # pylint: disable=C0413
12+
from paddlenlp.transformers import AutoModel, AutoTokenizer # pylint: disable=C0413
13+
14+
15+
class PaddleNLP(BaseEmbedding):
16+
def __init__(self, model: str = "ernie-3.0-medium-zh"):
17+
self.model = AutoModel.from_pretrained(model)
18+
self.model.eval()
19+
20+
self.tokenizer = AutoTokenizer.from_pretrained(model)
21+
if not self.tokenizer.pad_token:
22+
self.tokenizer.pad_token = "<pad>"
23+
self.__dimension = None
24+
25+
def to_embeddings(self, data, **_):
26+
"""Generate embedding given text input
27+
28+
:param data: text in string.
29+
:type data: str
30+
31+
:return: a text embedding in shape of (dim,).
32+
"""
33+
if not isinstance(data, list):
34+
data = [data]
35+
inputs = self.tokenizer(
36+
data, padding=True, truncation=True, return_tensors="pd"
37+
)
38+
outs = self.model(**inputs)[0]
39+
emb = self.post_proc(outs, inputs).squeeze(0).detach().numpy()
40+
return np.array(emb).astype("float32")
41+
42+
def post_proc(self, token_embeddings, inputs):
43+
attention_mask = paddle.ones(inputs["token_type_ids"].shape)
44+
input_mask_expanded = (
45+
attention_mask.unsqueeze(-1).expand(token_embeddings.shape).astype("float32")
46+
)
47+
sentence_embs = paddle.sum(
48+
token_embeddings * input_mask_expanded, 1
49+
) / paddle.clip(input_mask_expanded.sum(1), min=1e-9)
50+
return sentence_embs
51+
52+
@property
53+
def dimension(self):
54+
"""Embedding dimension.
55+
56+
:return: embedding dimension
57+
"""
58+
if not self.__dimension:
59+
self.__dimension = len(self.to_embeddings("foo"))
60+
return self.__dimension

modelcache/utils/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,3 +52,12 @@ def import_torch():
5252

5353
def import_fasttext():
5454
_check_library("fasttext")
55+
56+
57+
def import_paddle():
58+
prompt_install("protobuf==3.20.0")
59+
_check_library("paddlepaddle")
60+
61+
62+
def import_paddlenlp():
63+
_check_library("paddlenlp")

0 commit comments

Comments
 (0)