@@ -1,4 +1,5 @@
-import langchain.chains.rl_chain as rl_chain
+import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
+import langchain.chains.rl_chain.base as rl_chain
 from test_utils import MockEncoder
 import pytest
 from langchain.prompts.prompt import PromptTemplate
@@ -17,7 +18,7 @@ def setup():

 def test_multiple_ToSelectFrom_throws():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
     actions = ["0", "1", "2"]
     with pytest.raises(ValueError):
         chain.run(
@@ -29,15 +30,15 @@ def test_multiple_ToSelectFrom_throws():

 def test_missing_basedOn_from_throws():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
     actions = ["0", "1", "2"]
     with pytest.raises(ValueError):
         chain.run(action=rl_chain.ToSelectFrom(actions))


 def test_ToSelectFrom_not_a_list_throws():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
     actions = {"actions": ["0", "1", "2"]}
     with pytest.raises(ValueError):
         chain.run(
@@ -50,7 +51,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
     auto_val_llm = FakeListChatModel(responses=["3"])
-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
@@ -71,7 +72,7 @@ def test_update_with_delayed_score_force():
     llm, PROMPT = setup()
     # this LLM returns a number so that the auto validator will return that
     auto_val_llm = FakeListChatModel(responses=["3"])
-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
@@ -92,7 +93,7 @@ def test_update_with_delayed_score_force():

 def test_update_with_delayed_score():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, selection_scorer=None
     )
     actions = ["0", "1", "2"]
@@ -115,7 +116,7 @@ def score_response(self, inputs, llm_response: str) -> float:
             score = 200
             return score

-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
     )
     actions = ["0", "1", "2"]
@@ -130,8 +131,8 @@ def score_response(self, inputs, llm_response: str) -> float:

 def test_default_embeddings():
     llm, PROMPT = setup()
-    feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
-    chain = rl_chain.PickBest.from_llm(
+    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
     )

@@ -163,8 +164,8 @@ def test_default_embeddings():

 def test_default_embeddings_off():
     llm, PROMPT = setup()
-    feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
-    chain = rl_chain.PickBest.from_llm(
+    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
     )

@@ -188,8 +189,8 @@ def test_default_embeddings_off():

 def test_default_embeddings_mixed_w_explicit_user_embeddings():
     llm, PROMPT = setup()
-    feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
-    chain = rl_chain.PickBest.from_llm(
+    feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
     )

@@ -223,7 +224,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
 def test_default_no_scorer_specified():
     _, PROMPT = setup()
     chain_llm = FakeListChatModel(responses=[100])
-    chain = rl_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
     response = chain.run(
         User=rl_chain.BasedOn("Context"),
         action=rl_chain.ToSelectFrom(["0", "1", "2"]),
@@ -236,7 +237,7 @@ def test_default_no_scorer_specified():

 def test_explicitly_no_scorer():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm, prompt=PROMPT, selection_scorer=None
     )
     response = chain.run(
@@ -252,7 +253,7 @@ def test_explicitly_no_scorer():
 def test_auto_scorer_with_user_defined_llm():
     llm, PROMPT = setup()
     scorer_llm = FakeListChatModel(responses=[300])
-    chain = rl_chain.PickBest.from_llm(
+    chain = pick_best_chain.PickBest.from_llm(
         llm=llm,
         prompt=PROMPT,
         selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
@@ -269,7 +270,7 @@ def test_auto_scorer_with_user_defined_llm():

 def test_calling_chain_w_reserved_inputs_throws():
     llm, PROMPT = setup()
-    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
+    chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
     with pytest.raises(ValueError):
         chain.run(
             User=rl_chain.BasedOn("Context"),