-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path relationship_training.py
More file actions
70 lines (58 loc) · 2.83 KB
/
relationship_training.py
File metadata and controls
70 lines (58 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import math
import logging
import itertools
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer
import ydf # Yggdrasil Decision Forests
# ----- Configuration -----
# Local directory holding the Stella sentence-transformer model weights.
STELLA_MODEL_DIR = "stella_en_400M_v5"
# Output directory where the trained gradient-boosted-trees model is saved.
GBDT_MODEL_DIR = "gbdt"
# Input CSV with one row per dataset (dataset_id, title, description columns).
DATASET_CSV_PATH = "datasets_acordar.csv"
RELATIONSHIPS_CSV_PATH = "relationships.csv" # NEW: your labeled pairs
# Cached outputs: embedding matrix and the parallel array of dataset ids.
EMBEDDINGS_OUT = "dataset_acordar_embeddings.npy"
IDS_OUT = "dataset_acordar_ids.npy"
# -------------------------
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s: %(message)s")
# NOTE(review): set after `import torch`; this only takes effect because CUDA
# initialization is lazy and torch.cuda has not been touched yet -- confirm,
# or move the assignment above the torch import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Prefer the GPU when available; everything below runs on this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
# Step 1: embed every dataset's title+description and persist the results.
#
# Reads dataset_id/title/description from DATASET_CSV_PATH (na_filter=False
# guarantees empty cells come back as "" rather than NaN, so .strip() is
# safe), encodes one text per dataset, and saves both the embedding matrix
# and the parallel id array for reuse by later steps/runs.
datasets_df = pd.read_csv(DATASET_CSV_PATH, encoding="utf-8", na_filter=False)
dataset_ids = datasets_df['dataset_id'].tolist()
titles = datasets_df['title'].tolist()
descriptions = datasets_df['description'].tolist()
N = len(dataset_ids)
# NOTE(review): SentenceTransformer tokenizers normally insert their own
# special tokens, so these literal "[CLS]"/"[SEP]" prefixes may be redundant
# or duplicated -- kept byte-identical to preserve the existing embedding
# space; confirm against the Stella tokenizer before changing.
texts = ["[CLS]" + t.strip() + "[SEP]" + d.strip() for t, d in zip(titles, descriptions)]
model = SentenceTransformer(STELLA_MODEL_DIR, trust_remote_code=True).to(device)
EMBED_BATCH_SIZE = 128  # Batch size for embedding computation (adjust based on GPU memory)
# encode() batches internally via batch_size=, so the previous manual
# chunking loop was redundant -- and it crashed on an empty CSV, since
# np.vstack([]) raises. A single call handles both cases.
embeddings_matrix = np.asarray(
    model.encode(texts, batch_size=EMBED_BATCH_SIZE, show_progress_bar=False)
)
np.save(EMBEDDINGS_OUT, embeddings_matrix)
np.save(IDS_OUT, np.array(dataset_ids))
# Row lookup: dataset id -> row index in embeddings_matrix.
id_to_index = {ds_id: idx for idx, ds_id in enumerate(dataset_ids)}
# Step 2: load labeled relationship pairs, build pairwise features from the
# cached embeddings, then train and persist a GBDT classifier with YDF.
logging.info("Loading labeled pairs for training...")
rels_df = pd.read_csv(RELATIONSHIPS_CSV_PATH, encoding="utf-8", na_filter=False)
# Map each side of a labeled pair to its embedding row. An id missing from
# the dataset CSV would silently map to NaN and only blow up later inside
# fancy indexing with an opaque error -- fail fast with a clear message.
idx1 = rels_df['dataset_id1'].map(id_to_index)
idx2 = rels_df['dataset_id2'].map(id_to_index)
unmatched = rels_df.loc[idx1.isna() | idx2.isna(), ['dataset_id1', 'dataset_id2']]
if not unmatched.empty:
    raise ValueError(
        f"{len(unmatched)} labeled pair(s) reference dataset ids absent from "
        f"'{DATASET_CSV_PATH}', e.g.: {unmatched.head(5).to_dict('records')}"
    )
logging.info("Building feature matrix for training...")
# Feature vector for a pair = concat(embedding of id1, embedding of id2).
emb1 = embeddings_matrix[idx1.to_numpy(dtype=int)]
emb2 = embeddings_matrix[idx2.to_numpy(dtype=int)]
features = np.hstack((emb1, emb2))
# YDF expects a DataFrame; feature columns are just stringified indices.
feature_cols = [str(i) for i in range(features.shape[1])]
train_df = pd.DataFrame(features, columns=feature_cols)
train_df['relationship'] = rels_df['relationship']  # binary (0/1) or multiclass labels
logging.info("Training GBDT model on pairwise features...")
learner = ydf.GradientBoostedTreesLearner(label="relationship")
gbdt_model = learner.train(train_df)
logging.info(f"Saving trained GBDT model to '{GBDT_MODEL_DIR}'...")
gbdt_model.save(GBDT_MODEL_DIR)
logging.info("Model training and save complete.")