Skip to content

Commit 2a9f64b

Browse files
gustavocidornelas, whoseoyster
authored and committed
Completes OPEN-4812 Create model runner for self hosted LLMs
1 parent bbdec48 commit 2a9f64b

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

openlayer/model_runners/ll_model_runners.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -518,9 +518,21 @@ def __init__(
518518
"API key must be provided for self-hosted LLMs. "
519519
"Please pass it as the keyword argument 'api_key'"
520520
)
521+
if kwargs.get("input_key") is None:
522+
raise ValueError(
523+
"Input key must be provided for self-hosted LLMs. "
524+
"Please pass it as the keyword argument 'input_key'"
525+
)
526+
if kwargs.get("output_key") is None:
527+
raise ValueError(
528+
"Output key must be provided for self-hosted LLMs. "
529+
"Please pass it as the keyword argument 'output_key'"
530+
)
521531

522532
self.url = kwargs["url"]
523533
self.api_key = kwargs["api_key"]
534+
self.input_key = kwargs["input_key"]
535+
self.output_key = kwargs["output_key"]
524536
self._initialize_llm()
525537

526538
def _initialize_llm(self):
@@ -559,8 +571,7 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
559571
"Authorization": f"Bearer {self.api_key}",
560572
"Content-Type": "application/json",
561573
}
562-
# TODO: use correct input key
563-
data = {"inputs": llm_input}
574+
data = {self.input_key: llm_input}
564575
response = requests.post(self.url, headers=headers, json=data)
565576
if response.status_code == 200:
566577
response_data = response.json()[0]
@@ -570,9 +581,15 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
570581

571582
def _get_output(self, response: Dict[str, Any]) -> str:
572583
"""Gets the output from the response."""
573-
# TODO: use correct output key
574-
return response["generated_text"]
584+
return response[self.output_key]
575585

576586
def _get_cost_estimate(self, response: Dict[str, Any]) -> float:
577587
"""Estimates the cost from the response."""
578588
return 0
589+
590+
591+
class HuggingFaceModelRunner(SelfHostedLLModelRunner):
    """Wraps LLMs hosted in HuggingFace.

    Specializes the self-hosted runner by fixing the request field to
    ``"inputs"`` and the response field to ``"generated_text"``.
    """

    def __init__(self, url, api_key):
        """Initializes the runner for a HuggingFace-hosted endpoint.

        Args:
            url: Endpoint URL that requests are posted to.
            api_key: Token sent in the ``Authorization: Bearer`` header.
        """
        # The parent constructor reads url and api_key from its keyword
        # arguments, so they must be forwarded by name, not positionally.
        super().__init__(
            url=url,
            api_key=api_key,
            input_key="inputs",
            output_key="generated_text",
        )

openlayer/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ class ModelRunnerFactory:
110110
"Cohere": ll_model_runners.CohereGenerateModelRunner,
111111
"OpenAI": ll_model_runners.OpenAIChatCompletionRunner,
112112
"SelfHosted": ll_model_runners.SelfHostedLLModelRunner,
113+
"HuggingFace": ll_model_runners.HuggingFaceModelRunner,
113114
}
114115
_MODEL_RUNNERS = {
115116
tasks.TaskType.TabularClassification.value: traditional_ml_model_runners.ClassificationModelRunner,

0 commit comments

Comments
 (0)