Dev #75

15 changes: 13 additions & 2 deletions routellm/controller.py
@@ -24,7 +24,7 @@
},
"causal_llm": {"checkpoint_path": "routellm/causal_llm_gpt4_augmented"},
"bert": {"checkpoint_path": "routellm/bert_gpt4_augmented"},
"mf": {"checkpoint_path": "routellm/mf_gpt4_augmented"},
"mf": {"checkpoint_path": "madison-evans/routellm_all-MiniLM-L6-v2"},
}


@@ -48,7 +48,9 @@ def __init__(
api_base: Optional[str] = None,
api_key: Optional[str] = None,
progress_bar: bool = False,
+        hf_token: Optional[str] = None,  # Hugging Face API token for gated or private checkpoints
):
+        self.hf_token = hf_token
self.model_pair = ModelPair(strong=strong_model, weak=weak_model)
self.routers = {}
self.api_base = api_base
Expand All @@ -67,7 +69,8 @@ def __init__(
for router in routers:
if router_pbar is not None:
router_pbar.set_description(f"Loading {router}")
-            self.routers[router] = ROUTER_CLS[router](**config.get(router, {}))
+            self.routers[router] = ROUTER_CLS[router](hf_token=self.hf_token, **config.get(router, {}))


# Some Python magic to match the OpenAI Python SDK
self.chat = SimpleNamespace(
@@ -101,6 +104,14 @@ def _parse_model_name(self, model: str):
f"Invalid model {model}. Model name must be of the format 'router-[router name]-[threshold]."
)
return router, threshold

+    def get_routed_model(self, messages: list, router: str, threshold: float) -> str:
+        """
+        Get the routed model for a given message using the specified router and threshold.
+        """
+        self._validate_router_threshold(router, threshold)
+        routed_model = self._get_routed_model_for_completion(messages, router, threshold)
+        return routed_model

def _get_routed_model_for_completion(
self, messages: list, router: str, threshold: float
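
A minimal usage sketch of the two additions above: passing hf_token through Controller and resolving a route with the new public get_routed_model helper. The token string and threshold are placeholders and the model names are just the controller defaults, so read this as illustrative rather than canonical:

    from routellm.controller import Controller

    client = Controller(
        routers=["mf"],
        strong_model="gpt-4-1106-preview",
        weak_model="mixtral-8x7b-instruct-v0.1",
        hf_token="hf_...",  # placeholder; forwarded to every router for gated/private checkpoints
    )

    # Ask the router which model it would pick, without issuing a completion.
    model = client.get_routed_model(
        messages=[{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
        router="mf",
        threshold=0.5,  # example value; calibrate the threshold on your own traffic
    )
    print(model)  # prints either the strong or the weak model name

Exposing routing separately from completion lets callers do their own dispatch, e.g. to a serving stack the OpenAI-compatible client does not cover.
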
88 changes: 64 additions & 24 deletions routellm/routers/matrix_factorization/model.py
@@ -1,7 +1,14 @@
import torch
from huggingface_hub import PyTorchModelHubMixin

+from transformers import AutoTokenizer, AutoModel
from routellm.routers.similarity_weighted.utils import OPENAI_CLIENT
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)
MODEL_IDS = {
"RWKV-4-Raven-14B": 0,
@@ -70,7 +77,6 @@
"zephyr-7b-beta": 63,
}


class MFModel(torch.nn.Module, PyTorchModelHubMixin):
def __init__(
self,
@@ -79,51 +85,85 @@ def __init__(
text_dim,
num_classes,
use_proj,
+        use_openai_embeddings=False,  # Default to Hugging Face embeddings
+        embedding_model_name="BAAI/bge-base-en",  # Default HF embedding model (matches the training notebook)
+        hf_token=None,  # Hugging Face API token
):
super().__init__()
self._name = "TextMF"
self.use_proj = use_proj
-        self.P = torch.nn.Embedding(num_models, dim)
+        self.use_openai_embeddings = use_openai_embeddings
+        self.hf_token = hf_token
+        self.embedding_model_name = embedding_model_name

-        self.embedding_model = "text-embedding-3-small"
+        # Model embedding matrix
+        self.P = torch.nn.Embedding(num_models, dim)

if self.use_proj:
-            self.text_proj = torch.nn.Sequential(
-                torch.nn.Linear(text_dim, dim, bias=False)
-            )
+            self.text_proj = torch.nn.Linear(text_dim, dim, bias=False)
else:
-            assert (
-                text_dim == dim
-            ), f"text_dim {text_dim} must be equal to dim {dim} if not using projection"
+            assert text_dim == dim, f"text_dim {text_dim} must be equal to dim {dim} if not using projection"

-        self.classifier = torch.nn.Sequential(
-            torch.nn.Linear(dim, num_classes, bias=False)
-        )
+        self.classifier = torch.nn.Linear(dim, num_classes, bias=False)

+        if not self.use_openai_embeddings:
+            logger.info(f"Loading Hugging Face tokenizer and model: {self.embedding_model_name}")
+
+            # Load tokenizer & model (same setup as the training notebook)
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.embedding_model_name,
+                token=hf_token,
+            )
+            self.embedding_model = AutoModel.from_pretrained(
+                self.embedding_model_name,
+                token=hf_token,
+            )
+            self.embedding_model.eval()  # Inference mode
+            self.embedding_model.to(self.get_device())

def get_device(self):
return self.P.weight.device

+    def get_prompt_embedding(self, prompt):
+        """Embed a prompt via OpenAI, or by mean pooling Hugging Face token embeddings."""
+        logger.info(f"Generating embedding for prompt: {prompt[:30]}...")
+
+        if self.use_openai_embeddings:
+            # OpenAI path, matching the embedding call this PR removes from forward()
+            response = OPENAI_CLIENT.embeddings.create(
+                input=[prompt], model=self.embedding_model_name
+            )
+            return torch.tensor(response.data[0].embedding, device=self.get_device())
+
+        inputs = self.tokenizer(
+            prompt,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        ).to(self.get_device())
+
+        with torch.no_grad():
+            outputs = self.embedding_model(**inputs)
+            last_hidden_state = outputs.last_hidden_state
+
+        # Mean pooling over token embeddings (matches the training notebook)
+        prompt_embed = last_hidden_state.mean(dim=1).squeeze()
+
+        return prompt_embed

def forward(self, model_id, prompt):
model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device())

model_embed = self.P(model_id)
model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1)
+        prompt_embed = self.get_prompt_embedding(prompt)

-        prompt_embed = (
-            OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model)
-            .data[0]
-            .embedding
-        )
-        prompt_embed = torch.tensor(prompt_embed, device=self.get_device())
-        prompt_embed = self.text_proj(prompt_embed)
+        if self.use_proj:
+            prompt_embed = self.text_proj(prompt_embed)

return self.classifier(model_embed * prompt_embed).squeeze()

@torch.no_grad()
def pred_win_rate(self, model_a, model_b, prompt):
logits = self.forward([model_a, model_b], prompt)
-        winrate = torch.sigmoid(logits[0] - logits[1]).item()
+        raw_diff = logits[0] - logits[1]
+        winrate = torch.sigmoid(raw_diff).item()
+        logger.info(
+            f"For prompt: '{prompt[:30]}...', logits: {[float(x) for x in logits]}, "
+            f"raw difference: {raw_diff:.4f}, winrate (sigmoid): {winrate:.4f}"
+        )
return winrate

def load(self, path):
        self.load_state_dict(torch.load(path))
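
Aside: the mean pooling in get_prompt_embedding collapses the (batch, seq_len, hidden) tensor of token embeddings into one vector per prompt. A standalone sketch of the same computation, using the PR's default "BAAI/bge-base-en" encoder (a BERT-base model, hidden size 768); the sample prompt is arbitrary:

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en")
    model = AutoModel.from_pretrained("BAAI/bge-base-en").eval()

    inputs = tokenizer("What is 2 + 2?", return_tensors="pt")
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # shape (1, seq_len, 768)

    embedding = hidden.mean(dim=1).squeeze()  # mean over tokens -> shape (768,)

This pools over every position, including [CLS] and [SEP]; for a single unpadded prompt it is equivalent to an attention-masked mean, which is why the simpler form suffices here.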
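
And a worked example of the arithmetic in pred_win_rate: the predicted probability that model A beats model B is the sigmoid of the difference of their classifier logits, Bradley-Terry style. The logit values here are made up:

    import torch

    logit_a, logit_b = torch.tensor(2.3), torch.tensor(1.1)
    winrate = torch.sigmoid(logit_a - logit_b).item()
    print(f"{winrate:.4f}")  # 0.7685 -> model A favored in roughly 77% of battles

Upstream, the controller compares this win rate against the routing threshold: prompts whose predicted strong-model win rate clears the threshold are sent to the strong model, and the rest to the weak one.
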
79 changes: 71 additions & 8 deletions routellm/routers/routers.py
@@ -1,7 +1,8 @@
import abc
import functools
import random

+from transformers import AutoModel
import numpy as np
import torch
from datasets import concatenate_datasets, load_dataset
@@ -21,6 +21,13 @@
compute_tiers,
preprocess_battles,
)
+import logging
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
+)
+logger = logging.getLogger(__name__)


def no_parallel(cls):
@@ -211,34 +218,90 @@ class MatrixFactorizationRouter(Router):
def __init__(
self,
checkpoint_path,
-        # This is the model pair for scoring at inference time,
-        # and can be different from the model pair used for routing.
strong_model="gpt-4-1106-preview",
weak_model="mixtral-8x7b-instruct-v0.1",
hidden_size=128,
-        num_models=64,
-        text_dim=1536,
+        num_models=None,
+        text_dim=None,
num_classes=1,
use_proj=True,
+        use_openai_embeddings=True,
+        embedding_model_name=None,
+        hf_token=None,
):
"""
A simplified constructor that flattens the logic for:
1) Setting num_models from MODEL_IDS,
2) Determining embedding_model_name defaults,
3) Setting text_dim for OpenAI vs. HF embeddings,
4) Initializing the MFModel,
5) Setting strong/weak model IDs.
"""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Default num_models to the length of MODEL_IDS if not provided
num_models = num_models or len(MODEL_IDS)

# Decide which embedding model_name to use if none provided
if not embedding_model_name:
if use_openai_embeddings:
# e.g. "text-embedding-ada-002" or your default
embedding_model_name = "text-embedding-3-small"
else:
embedding_model_name = "BAAI/bge-base-en"

# Decide text_dim if not provided
if text_dim is None:
if use_openai_embeddings:
# e.g., 1536 for text-embedding-ada-002
text_dim = 1536
else:
text_dim = self._infer_hf_text_dim(embedding_model_name)

+        # Initialize the MFModel
self.model = MFModel.from_pretrained(
checkpoint_path,
dim=hidden_size,
num_models=num_models,
text_dim=text_dim,
num_classes=num_classes,
use_proj=use_proj,
-        )
-        self.model = self.model.eval().to(device)
+            use_openai_embeddings=use_openai_embeddings,
+            embedding_model_name=embedding_model_name,
+            hf_token=hf_token,
+        ).eval().to(device)

+        # Store strong/weak model IDs
self.strong_model_id = MODEL_IDS[strong_model]
self.weak_model_id = MODEL_IDS[weak_model]

+    @staticmethod
+    def _infer_hf_text_dim(embedding_model_name: str) -> int:
+        """
+        Load a Hugging Face model, read its hidden size, and
+        immediately free the model from memory.
+        """
+        hf_model = AutoModel.from_pretrained(embedding_model_name)
+        dim = hf_model.config.hidden_size
+
+        del hf_model
+
+        return dim

def calculate_strong_win_rate(self, prompt):
"""
Scores the prompt using the MF model to see how
often the 'strong' model is predicted to win
over the 'weak' model.
"""
winrate = self.model.pred_win_rate(
-            self.strong_model_id, self.weak_model_id, prompt
+            self.strong_model_id,
+            self.weak_model_id,
+            prompt,
)
+        logger.info(f"winrate: {winrate:.4f}")
return winrate


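
Finally, a sketch of constructing the router with Hugging Face embeddings, exercising the default-resolution logic above. The checkpoint path is the one wired into the mf entry in controller.py; pairing it with the sentence-transformers/all-MiniLM-L6-v2 encoder (text_dim then inferred as 384) is an assumption based on the checkpoint's name, not something this diff states:

    from routellm.routers.routers import MatrixFactorizationRouter

    router = MatrixFactorizationRouter(
        checkpoint_path="madison-evans/routellm_all-MiniLM-L6-v2",
        use_openai_embeddings=False,
        embedding_model_name="sentence-transformers/all-MiniLM-L6-v2",  # assumption: matches the checkpoint's encoder
        hf_token=None,  # set for gated or private checkpoints
    )
    print(router.calculate_strong_win_rate("Write a haiku about routing."))

Left at its defaults (use_openai_embeddings=True), the constructor instead selects "text-embedding-3-small" with text_dim=1536 and never loads a local encoder.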