@@ -518,9 +518,21 @@ def __init__(
518
518
"API key must be provided for self-hosted LLMs. "
519
519
"Please pass it as the keyword argument 'api_key'"
520
520
)
521
+ if kwargs .get ("input_key" ) is None :
522
+ raise ValueError (
523
+ "Input key must be provided for self-hosted LLMs. "
524
+ "Please pass it as the keyword argument 'input_key'"
525
+ )
526
+ if kwargs .get ("output_key" ) is None :
527
+ raise ValueError (
528
+ "Output key must be provided for self-hosted LLMs. "
529
+ "Please pass it as the keyword argument 'output_key'"
530
+ )
521
531
522
532
self .url = kwargs ["url" ]
523
533
self .api_key = kwargs ["api_key" ]
534
+ self .input_key = kwargs ["input_key" ]
535
+ self .output_key = kwargs ["output_key" ]
524
536
self ._initialize_llm ()
525
537
526
538
def _initialize_llm (self ):
@@ -559,8 +571,7 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
559
571
"Authorization" : f"Bearer { self .api_key } " ,
560
572
"Content-Type" : "application/json" ,
561
573
}
562
- # TODO: use correct input key
563
- data = {"inputs" : llm_input }
574
+ data = {self .input_key : llm_input }
564
575
response = requests .post (self .url , headers = headers , json = data )
565
576
if response .status_code == 200 :
566
577
response_data = response .json ()[0 ]
@@ -570,9 +581,15 @@ def _make_request(self, llm_input: str) -> Dict[str, Any]:
570
581
571
582
def _get_output (self , response : Dict [str , Any ]) -> str :
572
583
"""Gets the output from the response."""
573
- # TODO: use correct output key
574
- return response ["generated_text" ]
584
+ return response [self .output_key ]
575
585
576
586
def _get_cost_estimate (self , response : Dict [str , Any ]) -> float :
577
587
"""Estimates the cost from the response."""
578
588
return 0
589
+
590
+
591
+ class HuggingFaceModelRunner (SelfHostedLLModelRunner ):
592
+ """Wraps LLMs hosted in HuggingFace."""
593
+
594
+ def __init__ (self , url , api_key ):
595
+ super ().__init__ (url , api_key , input_key = "inputs" , output_key = "generated_text" )
0 commit comments