embedding.py
import logging

import torch
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import normalize

from config import settings

logger = logging.getLogger(__name__)
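
# NOTE (assumption): `config.settings` is expected to expose EMBEDDING_MODEL
# (a sentence-transformers model name or path), EMBEDDING_DEVICE ("auto",
# "cuda", or "cpu"), and EMBEDDING_BATCH_SIZE (an int); the config module
# itself is not shown in this file.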

# Resolve the target device: "auto" prefers CUDA when available, and an
# explicit "cuda" request silently falls back to CPU when no GPU is present.
if settings.EMBEDDING_DEVICE == "auto":
    device = "cuda" if torch.cuda.is_available() else "cpu"
elif settings.EMBEDDING_DEVICE == "cuda" and torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

logger.info(f"Using device: {device} (config: {settings.EMBEDDING_DEVICE})")

# Load the model once at import time so every caller shares one instance.
try:
    model = SentenceTransformer(settings.EMBEDDING_MODEL)
    if device == "cuda":
        model = model.to(device)
        logger.info(f"Model {settings.EMBEDDING_MODEL} loaded onto GPU")
    else:
        logger.info(f"Model {settings.EMBEDDING_MODEL} running on CPU")
except Exception as e:
    logger.error(f"Failed to load model: {e}")
    raise


def embedding(texts, batch_size=None):
    """Encode a list of texts into L2-normalized embedding vectors.

    Returns a numpy array of shape (len(texts), dim), or None for empty
    input. Because rows are L2-normalized, the dot product of two rows is
    their cosine similarity.
    """
    if batch_size is None:
        batch_size = settings.EMBEDDING_BATCH_SIZE
    try:
        if not texts:
            logger.warning("Empty texts list provided to embedding function")
            return None
        logger.debug(f"Creating embeddings for {len(texts)} texts with batch_size={batch_size}")
        embeddings = model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            convert_to_tensor=True,
            device=device,
        )
        # Move to CPU and convert to numpy before sklearn normalization.
        if isinstance(embeddings, torch.Tensor):
            embeddings = embeddings.cpu().numpy()
        normalized = normalize(embeddings)
        logger.debug(f"Embeddings created successfully, shape: {normalized.shape}")
        return normalized
    except Exception as e:
        logger.error(f"Error during embedding: {e}")
        if device == "cuda":
            # If the GPU path failed (e.g. out of memory), retry once on a
            # freshly loaded CPU model. Pinning device="cpu" matters here:
            # otherwise SentenceTransformer would auto-select CUDA again.
            logger.warning("GPU embedding failed, falling back to CPU...")
            try:
                model_cpu = SentenceTransformer(settings.EMBEDDING_MODEL, device="cpu")
                embeddings = model_cpu.encode(texts, batch_size=batch_size, show_progress_bar=True)
                return normalize(embeddings)
            except Exception as cpu_error:
                logger.error(f"CPU fallback also failed: {cpu_error}")
                raise
        else:
            raise


def get_embedding_model():
    """Return the shared SentenceTransformer instance."""
    return model


def get_device_info():
    """Report available hardware and the device the model currently uses."""
    if torch.cuda.is_available():
        return {
            "device": "cuda",
            "gpu_name": torch.cuda.get_device_name(0),
            "gpu_memory": f"{torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB",
            "current_device": device,
            "model": settings.EMBEDDING_MODEL,
        }
    else:
        return {
            "device": "cpu",
            "current_device": device,
            "model": settings.EMBEDDING_MODEL,
        }
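

# Minimal usage sketch: running the file directly embeds two sample
# sentences and prints the result. This assumes settings.EMBEDDING_MODEL
# names a valid sentence-transformers checkpoint loadable on this machine;
# the sample sentences are illustrative only.
if __name__ == "__main__":
    sample = ["hello world", "sentence embeddings map text to vectors"]
    vectors = embedding(sample)
    # Rows are L2-normalized, so a dot product is a cosine similarity.
    print("shape:", vectors.shape)
    print("cosine similarity:", float(vectors[0] @ vectors[1]))
    print(get_device_info())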