Skip to content

Commit 389aed2

Browse files
author
xusenlin
committed
Update langchain_llm package
1 parent c0d8a9f commit 389aed2

File tree

4 files changed

+12
-4
lines changed

4 files changed

+12
-4
lines changed

libs/langchain_llm/README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,4 +153,11 @@ chat_llm = ChatHuggingFace(
153153

154154

155155
+ `flash_attn`: Enable FlashAttention-2.
156-
+
156+
157+
## Merge Lora model
158+
159+
```python
160+
from langchain_llm import apply_lora
161+
162+
apply_lora("base_model_path", "lora_path", "target_model_path")
163+
```

libs/langchain_llm/langchain_llm/_huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _stream(
594594
token=chunk.text, verbose=self.verbose, log_probs=logprobs
595595
)
596596

597-
def _to_chat_prompt(self, messages: List[BaseMessage]) -> Union[List[int], Dict[str, Any]]:
597+
def _to_chat_prompt(self, messages: List[BaseMessage]) -> List[Dict]:
598598
"""Convert a list of messages into a prompt format expected by wrapped LLM."""
599599
if not messages:
600600
raise ValueError("at least one HumanMessage must be provided")
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
from .template import get_prompt_adapter
2+
from .model import load_model_and_tokenizer

libs/langchain_llm/langchain_llm/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch
44
from loguru import logger
55
from peft import PeftModel
6-
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel
6+
from transformers import AutoTokenizer, AutoModelForCausalLM
77

88

99
def apply_lora(
@@ -12,7 +12,7 @@ def apply_lora(
1212
target_model_path: str,
1313
max_shard_size: Optional[str] = "2GB",
1414
safe_serialization: Optional[bool] = True,
15-
) -> PreTrainedModel:
15+
):
1616

1717
logger.info(f"Loading the base model from {base_model_path}")
1818
base = AutoModelForCausalLM.from_pretrained(

0 commit comments

Comments
 (0)