1
- import langchain .chains .rl_chain as rl_chain
1
+ import langchain .chains .rl_chain .pick_best_chain as pick_best_chain
2
+ import langchain .chains .rl_chain .base as rl_chain
2
3
from test_utils import MockEncoder
3
4
import pytest
4
5
from langchain .prompts .prompt import PromptTemplate
@@ -17,7 +18,7 @@ def setup():
17
18
18
19
def test_multiple_ToSelectFrom_throws ():
19
20
llm , PROMPT = setup ()
20
- chain = rl_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
21
+ chain = pick_best_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
21
22
actions = ["0" , "1" , "2" ]
22
23
with pytest .raises (ValueError ):
23
24
chain .run (
@@ -29,15 +30,15 @@ def test_multiple_ToSelectFrom_throws():
29
30
30
31
def test_missing_basedOn_from_throws ():
31
32
llm , PROMPT = setup ()
32
- chain = rl_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
33
+ chain = pick_best_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
33
34
actions = ["0" , "1" , "2" ]
34
35
with pytest .raises (ValueError ):
35
36
chain .run (action = rl_chain .ToSelectFrom (actions ))
36
37
37
38
38
39
def test_ToSelectFrom_not_a_list_throws ():
39
40
llm , PROMPT = setup ()
40
- chain = rl_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
41
+ chain = pick_best_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
41
42
actions = {"actions" : ["0" , "1" , "2" ]}
42
43
with pytest .raises (ValueError ):
43
44
chain .run (
@@ -50,7 +51,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
50
51
llm , PROMPT = setup ()
51
52
# this LLM returns a number so that the auto validator will return that
52
53
auto_val_llm = FakeListChatModel (responses = ["3" ])
53
- chain = rl_chain .PickBest .from_llm (
54
+ chain = pick_best_chain .PickBest .from_llm (
54
55
llm = llm ,
55
56
prompt = PROMPT ,
56
57
selection_scorer = rl_chain .AutoSelectionScorer (llm = auto_val_llm ),
@@ -71,7 +72,7 @@ def test_update_with_delayed_score_force():
71
72
llm , PROMPT = setup ()
72
73
# this LLM returns a number so that the auto validator will return that
73
74
auto_val_llm = FakeListChatModel (responses = ["3" ])
74
- chain = rl_chain .PickBest .from_llm (
75
+ chain = pick_best_chain .PickBest .from_llm (
75
76
llm = llm ,
76
77
prompt = PROMPT ,
77
78
selection_scorer = rl_chain .AutoSelectionScorer (llm = auto_val_llm ),
@@ -92,7 +93,7 @@ def test_update_with_delayed_score_force():
92
93
93
94
def test_update_with_delayed_score ():
94
95
llm , PROMPT = setup ()
95
- chain = rl_chain .PickBest .from_llm (
96
+ chain = pick_best_chain .PickBest .from_llm (
96
97
llm = llm , prompt = PROMPT , selection_scorer = None
97
98
)
98
99
actions = ["0" , "1" , "2" ]
@@ -115,7 +116,7 @@ def score_response(self, inputs, llm_response: str) -> float:
115
116
score = 200
116
117
return score
117
118
118
- chain = rl_chain .PickBest .from_llm (
119
+ chain = pick_best_chain .PickBest .from_llm (
119
120
llm = llm , prompt = PROMPT , selection_scorer = CustomSelectionScorer ()
120
121
)
121
122
actions = ["0" , "1" , "2" ]
@@ -130,8 +131,8 @@ def score_response(self, inputs, llm_response: str) -> float:
130
131
131
132
def test_default_embeddings ():
132
133
llm , PROMPT = setup ()
133
- feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
134
- chain = rl_chain .PickBest .from_llm (
134
+ feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
135
+ chain = pick_best_chain .PickBest .from_llm (
135
136
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder
136
137
)
137
138
@@ -163,8 +164,8 @@ def test_default_embeddings():
163
164
164
165
def test_default_embeddings_off ():
165
166
llm , PROMPT = setup ()
166
- feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
167
- chain = rl_chain .PickBest .from_llm (
167
+ feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
168
+ chain = pick_best_chain .PickBest .from_llm (
168
169
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder , auto_embed = False
169
170
)
170
171
@@ -188,8 +189,8 @@ def test_default_embeddings_off():
188
189
189
190
def test_default_embeddings_mixed_w_explicit_user_embeddings ():
190
191
llm , PROMPT = setup ()
191
- feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
192
- chain = rl_chain .PickBest .from_llm (
192
+ feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
193
+ chain = pick_best_chain .PickBest .from_llm (
193
194
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder
194
195
)
195
196
@@ -223,7 +224,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
223
224
def test_default_no_scorer_specified ():
224
225
_ , PROMPT = setup ()
225
226
chain_llm = FakeListChatModel (responses = [100 ])
226
- chain = rl_chain .PickBest .from_llm (llm = chain_llm , prompt = PROMPT )
227
+ chain = pick_best_chain .PickBest .from_llm (llm = chain_llm , prompt = PROMPT )
227
228
response = chain .run (
228
229
User = rl_chain .BasedOn ("Context" ),
229
230
action = rl_chain .ToSelectFrom (["0" , "1" , "2" ]),
@@ -236,7 +237,7 @@ def test_default_no_scorer_specified():
236
237
237
238
def test_explicitly_no_scorer ():
238
239
llm , PROMPT = setup ()
239
- chain = rl_chain .PickBest .from_llm (
240
+ chain = pick_best_chain .PickBest .from_llm (
240
241
llm = llm , prompt = PROMPT , selection_scorer = None
241
242
)
242
243
response = chain .run (
@@ -252,7 +253,7 @@ def test_explicitly_no_scorer():
252
253
def test_auto_scorer_with_user_defined_llm ():
253
254
llm , PROMPT = setup ()
254
255
scorer_llm = FakeListChatModel (responses = [300 ])
255
- chain = rl_chain .PickBest .from_llm (
256
+ chain = pick_best_chain .PickBest .from_llm (
256
257
llm = llm ,
257
258
prompt = PROMPT ,
258
259
selection_scorer = rl_chain .AutoSelectionScorer (llm = scorer_llm ),
@@ -269,7 +270,7 @@ def test_auto_scorer_with_user_defined_llm():
269
270
270
271
def test_calling_chain_w_reserved_inputs_throws ():
271
272
llm , PROMPT = setup ()
272
- chain = rl_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
273
+ chain = pick_best_chain .PickBest .from_llm (llm = llm , prompt = PROMPT )
273
274
with pytest .raises (ValueError ):
274
275
chain .run (
275
276
User = rl_chain .BasedOn ("Context" ),
0 commit comments