     Bedrock,
     Cohere,
     GPT4All,
-    HuggingFaceHub,
+    HuggingFaceEndpoint,
     OpenAI,
     SagemakerEndpoint,
     Together,
@@ -318,7 +318,6 @@ def __init__(self, *args, **kwargs):
             ),
             "text": PromptTemplate.from_template("{prompt}"),  # No customization
         }
-
         super().__init__(*args, **kwargs, **model_kwargs)

     async def _call_in_executor(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
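
Aside: the "text" entry in the hunk above maps a prompt to itself. A minimal standalone sketch of that behavior with LangChain's PromptTemplate (illustrative only, not part of this commit):

from langchain.prompts import PromptTemplate

# "{prompt}" has a single variable and no surrounding text,
# so formatting passes the prompt through unchanged.
template = PromptTemplate.from_template("{prompt}")
assert template.format(prompt="Tell me a joke.") == "Tell me a joke."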
@@ -582,14 +581,10 @@ def allows_concurrency(self):
         return False


-HUGGINGFACE_HUB_VALID_TASKS = (
-    "text2text-generation",
-    "text-generation",
-    "text-to-image",
-)
-
-
-class HfHubProvider(BaseProvider, HuggingFaceHub):
+# References for using HuggingFaceEndpoint and InferenceClient:
+# https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient
+# https://github.com/langchain-ai/langchain/blob/master/libs/community/langchain_community/llms/huggingface_endpoint.py
+class HfHubProvider(BaseProvider, HuggingFaceEndpoint):
     id = "huggingface_hub"
     name = "Hugging Face Hub"
     models = ["*"]
@@ -609,33 +604,35 @@ class HfHubProvider(BaseProvider, HuggingFaceHub):
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
-        huggingfacehub_api_token = get_from_dict_or_env(
-            values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
-        )
         try:
-            from huggingface_hub.inference_api import InferenceApi
+            huggingfacehub_api_token = get_from_dict_or_env(
+                values, "huggingfacehub_api_token", "HUGGINGFACEHUB_API_TOKEN"
+            )
+        except Exception as e:
+            raise ValueError(
+                "Could not authenticate with huggingface_hub. "
+                "Please check your API token."
+            ) from e
+        try:
+            from huggingface_hub import InferenceClient

-            repo_id = values["repo_id"]
-            client = InferenceApi(
-                repo_id=repo_id,
+            values["client"] = InferenceClient(
+                model=values["model"],
+                timeout=values["timeout"],
                 token=huggingfacehub_api_token,
-                task=values.get("task"),
+                **values["server_kwargs"],
             )
-            if client.task not in HUGGINGFACE_HUB_VALID_TASKS:
-                raise ValueError(
-                    f"Got invalid task {client.task}, "
-                    f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
-                )
-            values["client"] = client
         except ImportError:
             raise ValueError(
                 "Could not import huggingface_hub python package. "
                 "Please install it with `pip install huggingface_hub`."
             )
         return values

-    # Handle image outputs
-    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
+    # Handle text and image outputs
+    def _call(
+        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
         """Call out to Hugging Face Hub's inference endpoint.

         Args:
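
Aside: the validate_environment hunk above swaps the legacy InferenceApi for InferenceClient. A minimal sketch of the equivalent client setup, using the documented text_generation convenience method rather than the provider's raw post call; the model id is a placeholder, not taken from this commit:

import os

from huggingface_hub import InferenceClient

# The provider resolves the token via get_from_dict_or_env; here we read
# the same HUGGINGFACEHUB_API_TOKEN environment variable directly.
client = InferenceClient(
    model="gpt2",  # placeholder repo id
    token=os.environ["HUGGINGFACEHUB_API_TOKEN"],
    timeout=120,
)
print(client.text_generation("Tell me a joke."))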
@@ -650,45 +647,51 @@ def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:

                 response = hf("Tell me a joke.")
         """
-        _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
-
-        if type(response) is dict and "error" in response:
-            raise ValueError(f"Error raised by inference API: {response['error']}")
-
-        # Custom code for responding to image generation responses
-        if self.client.task == "text-to-image":
-            imageFormat = response.format  # Presume it's a PIL ImageFile
-            mimeType = ""
-            if imageFormat == "JPEG":
-                mimeType = "image/jpeg"
-            elif imageFormat == "PNG":
-                mimeType = "image/png"
-            elif imageFormat == "GIF":
-                mimeType = "image/gif"
+        invocation_params = self._invocation_params(stop, **kwargs)
+        invocation_params["stop"] = invocation_params[
+            "stop_sequences"
+        ]  # porting 'stop_sequences' into the 'stop' argument
+        response = self.client.post(
+            json={"inputs": prompt, "parameters": invocation_params},
+            stream=False,
+            task=self.task,
+        )
+
+        try:
+            if "generated_text" in str(response):
+                # text2text-generation or text-generation task
+                response_text = json.loads(response.decode())[0]["generated_text"]
+                # Maybe the generation has stopped at one of the stop sequences:
+                # then we remove this stop sequence from the end of the generated text
+                for stop_seq in invocation_params["stop_sequences"]:
+                    if response_text[-len(stop_seq) :] == stop_seq:
+                        response_text = response_text[: -len(stop_seq)]
+                return response_text
             else:
-                raise ValueError(f"Unrecognized image format {imageFormat}")
-
-            buffer = io.BytesIO()
-            response.save(buffer, format=imageFormat)
-            # Encode image data to Base64 bytes, then decode bytes to str
-            return mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
-
-        if self.client.task == "text-generation":
-            # Text generation return includes the starter text.
-            text = response[0]["generated_text"][len(prompt) :]
-        elif self.client.task == "text2text-generation":
-            text = response[0]["generated_text"]
-        else:
+                # text-to-image task
+                # https://huggingface.co/docs/huggingface_hub/main/en/package_reference/inference_client#huggingface_hub.InferenceClient.text_to_image.example
+                # Custom code for responding to image generation responses
+                image = self.client.text_to_image(prompt)
+                imageFormat = image.format  # Presume it's a PIL ImageFile
+                mimeType = ""
+                if imageFormat == "JPEG":
+                    mimeType = "image/jpeg"
+                elif imageFormat == "PNG":
+                    mimeType = "image/png"
+                elif imageFormat == "GIF":
+                    mimeType = "image/gif"
+                else:
+                    raise ValueError(f"Unrecognized image format {imageFormat}")
+                buffer = io.BytesIO()
+                image.save(buffer, format=imageFormat)
+                # Encode image data to Base64 bytes, then decode bytes to str
+                return (
+                    mimeType + ";base64," + base64.b64encode(buffer.getvalue()).decode()
+                )
+        except Exception:
             raise ValueError(
-                f"Got invalid task {self.client.task}, "
-                f"currently only {HUGGINGFACE_HUB_VALID_TASKS} are supported"
+                "Task not supported, only text-generation and text-to-image tasks are valid."
             )
-        if stop is not None:
-            # This is a bit hacky, but I can't figure out a better way to enforce
-            # stop tokens when making calls to huggingface_hub.
-            text = enforce_stop_tokens(text, stop)
-        return text

     async def _acall(self, *args, **kwargs) -> Coroutine[Any, Any, str]:
         return await self._call_in_executor(*args, **kwargs)
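
Aside: the base64 handling in the text-to-image branch can be exercised in isolation. A minimal sketch, assuming a PIL image like the one InferenceClient.text_to_image returns; the solid-color stand-in image is illustrative, not from this commit:

import base64
import io

from PIL import Image

# Stand-in for `image = self.client.text_to_image(prompt)`.
image = Image.new("RGB", (4, 4), "red")

buffer = io.BytesIO()
image.save(buffer, format="PNG")
# Mirrors the provider's return value: "<mime type>;base64,<payload>".
print("image/png;base64," + base64.b64encode(buffer.getvalue()).decode())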