1
- import langchain .chains .rl_chain . pick_best_chain as pick_best_chain
1
+ import langchain .chains .rl_chain as rl_chain
2
2
from test_utils import MockEncoder
3
3
import pytest
4
4
from langchain .prompts .prompt import PromptTemplate
@@ -17,48 +17,48 @@ def setup():
17
17
18
18
def test_multiple_ToSelectFrom_throws():
    """Expect ``ValueError`` when two inputs are wrapped in ``ToSelectFrom``."""
    llm, PROMPT = setup()
    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
    actions = ["0", "1", "2"]
    with pytest.raises(ValueError):
        # Both `action` and `another_action` carry a ToSelectFrom wrapper.
        chain.run(
            User=rl_chain.BasedOn("Context"),
            action=rl_chain.ToSelectFrom(actions),
            another_action=rl_chain.ToSelectFrom(actions),
        )
28
28
29
29
30
30
def test_missing_basedOn_from_throws():
    """Expect ``ValueError`` when ``run`` gets a ``ToSelectFrom`` but no ``BasedOn``."""
    llm, PROMPT = setup()
    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
    actions = ["0", "1", "2"]
    with pytest.raises(ValueError):
        # No BasedOn input is supplied here on purpose.
        chain.run(action=rl_chain.ToSelectFrom(actions))
36
36
37
37
38
38
def test_ToSelectFrom_not_a_list_throws():
    """Expect ``ValueError`` when ``ToSelectFrom`` wraps a dict instead of a list."""
    llm, PROMPT = setup()
    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
    # A dict, not a list, is handed to ToSelectFrom below.
    actions = {"actions": ["0", "1", "2"]}
    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            action=rl_chain.ToSelectFrom(actions),
        )
47
47
48
48
49
49
def test_update_with_delayed_score_with_auto_validator_throws ():
50
50
llm , PROMPT = setup ()
51
51
# this LLM returns a number so that the auto validator will return that
52
52
auto_val_llm = FakeListChatModel (responses = ["3" ])
53
- chain = pick_best_chain .PickBest .from_llm (
53
+ chain = rl_chain .PickBest .from_llm (
54
54
llm = llm ,
55
55
prompt = PROMPT ,
56
- selection_scorer = pick_best_chain . base .AutoSelectionScorer (llm = auto_val_llm ),
56
+ selection_scorer = rl_chain .AutoSelectionScorer (llm = auto_val_llm ),
57
57
)
58
58
actions = ["0" , "1" , "2" ]
59
59
response = chain .run (
60
- User = pick_best_chain . base .BasedOn ("Context" ),
61
- action = pick_best_chain . base .ToSelectFrom (actions ),
60
+ User = rl_chain .BasedOn ("Context" ),
61
+ action = rl_chain .ToSelectFrom (actions ),
62
62
)
63
63
assert response ["response" ] == "hey"
64
64
selection_metadata = response ["selection_metadata" ]
@@ -71,15 +71,15 @@ def test_update_with_delayed_score_force():
71
71
llm , PROMPT = setup ()
72
72
# this LLM returns a number so that the auto validator will return that
73
73
auto_val_llm = FakeListChatModel (responses = ["3" ])
74
- chain = pick_best_chain .PickBest .from_llm (
74
+ chain = rl_chain .PickBest .from_llm (
75
75
llm = llm ,
76
76
prompt = PROMPT ,
77
- selection_scorer = pick_best_chain . base .AutoSelectionScorer (llm = auto_val_llm ),
77
+ selection_scorer = rl_chain .AutoSelectionScorer (llm = auto_val_llm ),
78
78
)
79
79
actions = ["0" , "1" , "2" ]
80
80
response = chain .run (
81
- User = pick_best_chain . base .BasedOn ("Context" ),
82
- action = pick_best_chain . base .ToSelectFrom (actions ),
81
+ User = rl_chain .BasedOn ("Context" ),
82
+ action = rl_chain .ToSelectFrom (actions ),
83
83
)
84
84
assert response ["response" ] == "hey"
85
85
selection_metadata = response ["selection_metadata" ]
@@ -92,13 +92,13 @@ def test_update_with_delayed_score_force():
92
92
93
93
def test_update_with_delayed_score ():
94
94
llm , PROMPT = setup ()
95
- chain = pick_best_chain .PickBest .from_llm (
95
+ chain = rl_chain .PickBest .from_llm (
96
96
llm = llm , prompt = PROMPT , selection_scorer = None
97
97
)
98
98
actions = ["0" , "1" , "2" ]
99
99
response = chain .run (
100
- User = pick_best_chain . base .BasedOn ("Context" ),
101
- action = pick_best_chain . base .ToSelectFrom (actions ),
100
+ User = rl_chain .BasedOn ("Context" ),
101
+ action = rl_chain .ToSelectFrom (actions ),
102
102
)
103
103
assert response ["response" ] == "hey"
104
104
selection_metadata = response ["selection_metadata" ]
@@ -110,18 +110,18 @@ def test_update_with_delayed_score():
110
110
def test_user_defined_scorer ():
111
111
llm , PROMPT = setup ()
112
112
113
- class CustomSelectionScorer (pick_best_chain . base .SelectionScorer ):
113
+ class CustomSelectionScorer (rl_chain .SelectionScorer ):
114
114
def score_response (self , inputs , llm_response : str ) -> float :
115
115
score = 200
116
116
return score
117
117
118
- chain = pick_best_chain .PickBest .from_llm (
118
+ chain = rl_chain .PickBest .from_llm (
119
119
llm = llm , prompt = PROMPT , selection_scorer = CustomSelectionScorer ()
120
120
)
121
121
actions = ["0" , "1" , "2" ]
122
122
response = chain .run (
123
- User = pick_best_chain . base .BasedOn ("Context" ),
124
- action = pick_best_chain . base .ToSelectFrom (actions ),
123
+ User = rl_chain .BasedOn ("Context" ),
124
+ action = rl_chain .ToSelectFrom (actions ),
125
125
)
126
126
assert response ["response" ] == "hey"
127
127
selection_metadata = response ["selection_metadata" ]
@@ -130,8 +130,8 @@ def score_response(self, inputs, llm_response: str) -> float:
130
130
131
131
def test_default_embeddings ():
132
132
llm , PROMPT = setup ()
133
- feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
134
- chain = pick_best_chain .PickBest .from_llm (
133
+ feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
134
+ chain = rl_chain .PickBest .from_llm (
135
135
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder
136
136
)
137
137
@@ -153,8 +153,8 @@ def test_default_embeddings():
153
153
actions = [str1 , str2 , str3 ]
154
154
155
155
response = chain .run (
156
- User = pick_best_chain . base .BasedOn (ctx_str_1 ),
157
- action = pick_best_chain . base .ToSelectFrom (actions ),
156
+ User = rl_chain .BasedOn (ctx_str_1 ),
157
+ action = rl_chain .ToSelectFrom (actions ),
158
158
)
159
159
selection_metadata = response ["selection_metadata" ]
160
160
vw_str = feature_embedder .format (selection_metadata )
@@ -163,8 +163,8 @@ def test_default_embeddings():
163
163
164
164
def test_default_embeddings_off ():
165
165
llm , PROMPT = setup ()
166
- feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
167
- chain = pick_best_chain .PickBest .from_llm (
166
+ feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
167
+ chain = rl_chain .PickBest .from_llm (
168
168
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder , auto_embed = False
169
169
)
170
170
@@ -178,8 +178,8 @@ def test_default_embeddings_off():
178
178
actions = [str1 , str2 , str3 ]
179
179
180
180
response = chain .run (
181
- User = pick_best_chain . base .BasedOn (ctx_str_1 ),
182
- action = pick_best_chain . base .ToSelectFrom (actions ),
181
+ User = rl_chain .BasedOn (ctx_str_1 ),
182
+ action = rl_chain .ToSelectFrom (actions ),
183
183
)
184
184
selection_metadata = response ["selection_metadata" ]
185
185
vw_str = feature_embedder .format (selection_metadata )
@@ -188,8 +188,8 @@ def test_default_embeddings_off():
188
188
189
189
def test_default_embeddings_mixed_w_explicit_user_embeddings ():
190
190
llm , PROMPT = setup ()
191
- feature_embedder = pick_best_chain .PickBestFeatureEmbedder (model = MockEncoder ())
192
- chain = pick_best_chain .PickBest .from_llm (
191
+ feature_embedder = rl_chain .PickBestFeatureEmbedder (model = MockEncoder ())
192
+ chain = rl_chain .PickBest .from_llm (
193
193
llm = llm , prompt = PROMPT , feature_embedder = feature_embedder
194
194
)
195
195
@@ -208,12 +208,12 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
208
208
209
209
expected = f"""shared |User { encoded_ctx_str_1 } |User2 { ctx_str_2 + " " + encoded_ctx_str_2 } \n |action { str1 + " " + encoded_str1 } \n |action { str2 + " " + encoded_str2 } \n |action { encoded_str3 } """
210
210
211
- actions = [str1 , str2 , pick_best_chain . base .Embed (str3 )]
211
+ actions = [str1 , str2 , rl_chain .Embed (str3 )]
212
212
213
213
response = chain .run (
214
- User = pick_best_chain . base . BasedOn (pick_best_chain . base .Embed (ctx_str_1 )),
215
- User2 = pick_best_chain . base .BasedOn (ctx_str_2 ),
216
- action = pick_best_chain . base .ToSelectFrom (actions ),
214
+ User = rl_chain . BasedOn (rl_chain .Embed (ctx_str_1 )),
215
+ User2 = rl_chain .BasedOn (ctx_str_2 ),
216
+ action = rl_chain .ToSelectFrom (actions ),
217
217
)
218
218
selection_metadata = response ["selection_metadata" ]
219
219
vw_str = feature_embedder .format (selection_metadata )
@@ -223,10 +223,10 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
223
223
def test_default_no_scorer_specified ():
224
224
_ , PROMPT = setup ()
225
225
chain_llm = FakeListChatModel (responses = [100 ])
226
- chain = pick_best_chain .PickBest .from_llm (llm = chain_llm , prompt = PROMPT )
226
+ chain = rl_chain .PickBest .from_llm (llm = chain_llm , prompt = PROMPT )
227
227
response = chain .run (
228
- User = pick_best_chain . base .BasedOn ("Context" ),
229
- action = pick_best_chain . base .ToSelectFrom (["0" , "1" , "2" ]),
228
+ User = rl_chain .BasedOn ("Context" ),
229
+ action = rl_chain .ToSelectFrom (["0" , "1" , "2" ]),
230
230
)
231
231
# chain llm used for both basic prompt and for scoring
232
232
assert response ["response" ] == "100"
@@ -236,12 +236,12 @@ def test_default_no_scorer_specified():
236
236
237
237
def test_explicitly_no_scorer ():
238
238
llm , PROMPT = setup ()
239
- chain = pick_best_chain .PickBest .from_llm (
239
+ chain = rl_chain .PickBest .from_llm (
240
240
llm = llm , prompt = PROMPT , selection_scorer = None
241
241
)
242
242
response = chain .run (
243
- User = pick_best_chain . base .BasedOn ("Context" ),
244
- action = pick_best_chain . base .ToSelectFrom (["0" , "1" , "2" ]),
243
+ User = rl_chain .BasedOn ("Context" ),
244
+ action = rl_chain .ToSelectFrom (["0" , "1" , "2" ]),
245
245
)
246
246
# chain llm used for both basic prompt and for scoring
247
247
assert response ["response" ] == "hey"
@@ -252,14 +252,14 @@ def test_explicitly_no_scorer():
252
252
def test_auto_scorer_with_user_defined_llm ():
253
253
llm , PROMPT = setup ()
254
254
scorer_llm = FakeListChatModel (responses = [300 ])
255
- chain = pick_best_chain .PickBest .from_llm (
255
+ chain = rl_chain .PickBest .from_llm (
256
256
llm = llm ,
257
257
prompt = PROMPT ,
258
- selection_scorer = pick_best_chain . base .AutoSelectionScorer (llm = scorer_llm ),
258
+ selection_scorer = rl_chain .AutoSelectionScorer (llm = scorer_llm ),
259
259
)
260
260
response = chain .run (
261
- User = pick_best_chain . base .BasedOn ("Context" ),
262
- action = pick_best_chain . base .ToSelectFrom (["0" , "1" , "2" ]),
261
+ User = rl_chain .BasedOn ("Context" ),
262
+ action = rl_chain .ToSelectFrom (["0" , "1" , "2" ]),
263
263
)
264
264
# chain llm used for both basic prompt and for scoring
265
265
assert response ["response" ] == "hey"
@@ -269,17 +269,17 @@ def test_auto_scorer_with_user_defined_llm():
269
269
270
270
def test_calling_chain_w_reserved_inputs_throws():
    """Expect ``ValueError`` when an input is passed under a reserved keyword name.

    Two names are exercised, each in its own ``pytest.raises`` block:
    ``rl_chain_selected_based_on`` and ``rl_chain_selected``.
    """
    llm, PROMPT = setup()
    chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)

    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            rl_chain_selected_based_on=rl_chain.ToSelectFrom(
                ["0", "1", "2"]
            ),
        )

    with pytest.raises(ValueError):
        chain.run(
            User=rl_chain.BasedOn("Context"),
            rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
        )
0 commit comments