Skip to content

Commit b422dc0

Browse files
committed
fix imports
1 parent c37fd29 commit b422dc0

File tree

6 files changed

+124
-124
lines changed

6 files changed

+124
-124
lines changed

libs/langchain/langchain/chains/rl_chain/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from .pick_best_chain import PickBest
2-
from .rl_chain_base import (
1+
from langchain.chains.rl_chain.pick_best_chain import PickBest
2+
from langchain.chains.rl_chain.rl_chain_base import (
33
Embed,
44
BasedOn,
55
ToSelectFrom,

libs/langchain/langchain/chains/rl_chain/pick_best_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
from . import rl_chain_base as base
3+
import langchain.chains.rl_chain.rl_chain_base as base
44

55
from langchain.callbacks.manager import CallbackManagerForChainRun
66
from langchain.chains.base import Chain

libs/langchain/langchain/chains/rl_chain/rl_chain_base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from abc import ABC, abstractmethod
77

88
import vowpal_wabbit_next as vw
9-
from .vw_logger import VwLogger
10-
from .model_repository import ModelRepository
11-
from .metrics import MetricsTracker
9+
from langchain.chains.rl_chain.vw_logger import VwLogger
10+
from langchain.chains.rl_chain.model_repository import ModelRepository
11+
from langchain.chains.rl_chain.metrics import MetricsTracker
1212
from langchain.prompts import BasePromptTemplate
1313

1414
from langchain.pydantic_v1 import Extra, BaseModel, root_validator

libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
1+
import langchain.chains.rl_chain as rl_chain
22
from test_utils import MockEncoder
33
import pytest
44
from langchain.prompts.prompt import PromptTemplate
@@ -17,48 +17,48 @@ def setup():
1717

1818
def test_multiple_ToSelectFrom_throws():
1919
llm, PROMPT = setup()
20-
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
20+
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
2121
actions = ["0", "1", "2"]
2222
with pytest.raises(ValueError):
2323
chain.run(
24-
User=pick_best_chain.base.BasedOn("Context"),
25-
action=pick_best_chain.base.ToSelectFrom(actions),
26-
another_action=pick_best_chain.base.ToSelectFrom(actions),
24+
User=rl_chain.BasedOn("Context"),
25+
action=rl_chain.ToSelectFrom(actions),
26+
another_action=rl_chain.ToSelectFrom(actions),
2727
)
2828

2929

3030
def test_missing_basedOn_from_throws():
3131
llm, PROMPT = setup()
32-
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
32+
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
3333
actions = ["0", "1", "2"]
3434
with pytest.raises(ValueError):
35-
chain.run(action=pick_best_chain.base.ToSelectFrom(actions))
35+
chain.run(action=rl_chain.ToSelectFrom(actions))
3636

3737

3838
def test_ToSelectFrom_not_a_list_throws():
3939
llm, PROMPT = setup()
40-
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
40+
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
4141
actions = {"actions": ["0", "1", "2"]}
4242
with pytest.raises(ValueError):
4343
chain.run(
44-
User=pick_best_chain.base.BasedOn("Context"),
45-
action=pick_best_chain.base.ToSelectFrom(actions),
44+
User=rl_chain.BasedOn("Context"),
45+
action=rl_chain.ToSelectFrom(actions),
4646
)
4747

4848

4949
def test_update_with_delayed_score_with_auto_validator_throws():
5050
llm, PROMPT = setup()
5151
# this LLM returns a number so that the auto validator will return that
5252
auto_val_llm = FakeListChatModel(responses=["3"])
53-
chain = pick_best_chain.PickBest.from_llm(
53+
chain = rl_chain.PickBest.from_llm(
5454
llm=llm,
5555
prompt=PROMPT,
56-
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=auto_val_llm),
56+
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
5757
)
5858
actions = ["0", "1", "2"]
5959
response = chain.run(
60-
User=pick_best_chain.base.BasedOn("Context"),
61-
action=pick_best_chain.base.ToSelectFrom(actions),
60+
User=rl_chain.BasedOn("Context"),
61+
action=rl_chain.ToSelectFrom(actions),
6262
)
6363
assert response["response"] == "hey"
6464
selection_metadata = response["selection_metadata"]
@@ -71,15 +71,15 @@ def test_update_with_delayed_score_force():
7171
llm, PROMPT = setup()
7272
# this LLM returns a number so that the auto validator will return that
7373
auto_val_llm = FakeListChatModel(responses=["3"])
74-
chain = pick_best_chain.PickBest.from_llm(
74+
chain = rl_chain.PickBest.from_llm(
7575
llm=llm,
7676
prompt=PROMPT,
77-
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=auto_val_llm),
77+
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
7878
)
7979
actions = ["0", "1", "2"]
8080
response = chain.run(
81-
User=pick_best_chain.base.BasedOn("Context"),
82-
action=pick_best_chain.base.ToSelectFrom(actions),
81+
User=rl_chain.BasedOn("Context"),
82+
action=rl_chain.ToSelectFrom(actions),
8383
)
8484
assert response["response"] == "hey"
8585
selection_metadata = response["selection_metadata"]
@@ -92,13 +92,13 @@ def test_update_with_delayed_score_force():
9292

9393
def test_update_with_delayed_score():
9494
llm, PROMPT = setup()
95-
chain = pick_best_chain.PickBest.from_llm(
95+
chain = rl_chain.PickBest.from_llm(
9696
llm=llm, prompt=PROMPT, selection_scorer=None
9797
)
9898
actions = ["0", "1", "2"]
9999
response = chain.run(
100-
User=pick_best_chain.base.BasedOn("Context"),
101-
action=pick_best_chain.base.ToSelectFrom(actions),
100+
User=rl_chain.BasedOn("Context"),
101+
action=rl_chain.ToSelectFrom(actions),
102102
)
103103
assert response["response"] == "hey"
104104
selection_metadata = response["selection_metadata"]
@@ -110,18 +110,18 @@ def test_update_with_delayed_score():
110110
def test_user_defined_scorer():
111111
llm, PROMPT = setup()
112112

113-
class CustomSelectionScorer(pick_best_chain.base.SelectionScorer):
113+
class CustomSelectionScorer(rl_chain.SelectionScorer):
114114
def score_response(self, inputs, llm_response: str) -> float:
115115
score = 200
116116
return score
117117

118-
chain = pick_best_chain.PickBest.from_llm(
118+
chain = rl_chain.PickBest.from_llm(
119119
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
120120
)
121121
actions = ["0", "1", "2"]
122122
response = chain.run(
123-
User=pick_best_chain.base.BasedOn("Context"),
124-
action=pick_best_chain.base.ToSelectFrom(actions),
123+
User=rl_chain.BasedOn("Context"),
124+
action=rl_chain.ToSelectFrom(actions),
125125
)
126126
assert response["response"] == "hey"
127127
selection_metadata = response["selection_metadata"]
@@ -130,8 +130,8 @@ def score_response(self, inputs, llm_response: str) -> float:
130130

131131
def test_default_embeddings():
132132
llm, PROMPT = setup()
133-
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
134-
chain = pick_best_chain.PickBest.from_llm(
133+
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
134+
chain = rl_chain.PickBest.from_llm(
135135
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
136136
)
137137

@@ -153,8 +153,8 @@ def test_default_embeddings():
153153
actions = [str1, str2, str3]
154154

155155
response = chain.run(
156-
User=pick_best_chain.base.BasedOn(ctx_str_1),
157-
action=pick_best_chain.base.ToSelectFrom(actions),
156+
User=rl_chain.BasedOn(ctx_str_1),
157+
action=rl_chain.ToSelectFrom(actions),
158158
)
159159
selection_metadata = response["selection_metadata"]
160160
vw_str = feature_embedder.format(selection_metadata)
@@ -163,8 +163,8 @@ def test_default_embeddings():
163163

164164
def test_default_embeddings_off():
165165
llm, PROMPT = setup()
166-
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
167-
chain = pick_best_chain.PickBest.from_llm(
166+
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
167+
chain = rl_chain.PickBest.from_llm(
168168
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
169169
)
170170

@@ -178,8 +178,8 @@ def test_default_embeddings_off():
178178
actions = [str1, str2, str3]
179179

180180
response = chain.run(
181-
User=pick_best_chain.base.BasedOn(ctx_str_1),
182-
action=pick_best_chain.base.ToSelectFrom(actions),
181+
User=rl_chain.BasedOn(ctx_str_1),
182+
action=rl_chain.ToSelectFrom(actions),
183183
)
184184
selection_metadata = response["selection_metadata"]
185185
vw_str = feature_embedder.format(selection_metadata)
@@ -188,8 +188,8 @@ def test_default_embeddings_off():
188188

189189
def test_default_embeddings_mixed_w_explicit_user_embeddings():
190190
llm, PROMPT = setup()
191-
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
192-
chain = pick_best_chain.PickBest.from_llm(
191+
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
192+
chain = rl_chain.PickBest.from_llm(
193193
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
194194
)
195195

@@ -208,12 +208,12 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
208208

209209
expected = f"""shared |User {encoded_ctx_str_1} |User2 {ctx_str_2 + " " + encoded_ctx_str_2} \n|action {str1 + " " + encoded_str1} \n|action {str2 + " " + encoded_str2} \n|action {encoded_str3} """
210210

211-
actions = [str1, str2, pick_best_chain.base.Embed(str3)]
211+
actions = [str1, str2, rl_chain.Embed(str3)]
212212

213213
response = chain.run(
214-
User=pick_best_chain.base.BasedOn(pick_best_chain.base.Embed(ctx_str_1)),
215-
User2=pick_best_chain.base.BasedOn(ctx_str_2),
216-
action=pick_best_chain.base.ToSelectFrom(actions),
214+
User=rl_chain.BasedOn(rl_chain.Embed(ctx_str_1)),
215+
User2=rl_chain.BasedOn(ctx_str_2),
216+
action=rl_chain.ToSelectFrom(actions),
217217
)
218218
selection_metadata = response["selection_metadata"]
219219
vw_str = feature_embedder.format(selection_metadata)
@@ -223,10 +223,10 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
223223
def test_default_no_scorer_specified():
224224
_, PROMPT = setup()
225225
chain_llm = FakeListChatModel(responses=[100])
226-
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
226+
chain = rl_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
227227
response = chain.run(
228-
User=pick_best_chain.base.BasedOn("Context"),
229-
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
228+
User=rl_chain.BasedOn("Context"),
229+
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
230230
)
231231
# chain llm used for both basic prompt and for scoring
232232
assert response["response"] == "100"
@@ -236,12 +236,12 @@ def test_default_no_scorer_specified():
236236

237237
def test_explicitly_no_scorer():
238238
llm, PROMPT = setup()
239-
chain = pick_best_chain.PickBest.from_llm(
239+
chain = rl_chain.PickBest.from_llm(
240240
llm=llm, prompt=PROMPT, selection_scorer=None
241241
)
242242
response = chain.run(
243-
User=pick_best_chain.base.BasedOn("Context"),
244-
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
243+
User=rl_chain.BasedOn("Context"),
244+
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
245245
)
246246
# chain llm used for both basic prompt and for scoring
247247
assert response["response"] == "hey"
@@ -252,14 +252,14 @@ def test_explicitly_no_scorer():
252252
def test_auto_scorer_with_user_defined_llm():
253253
llm, PROMPT = setup()
254254
scorer_llm = FakeListChatModel(responses=[300])
255-
chain = pick_best_chain.PickBest.from_llm(
255+
chain = rl_chain.PickBest.from_llm(
256256
llm=llm,
257257
prompt=PROMPT,
258-
selection_scorer=pick_best_chain.base.AutoSelectionScorer(llm=scorer_llm),
258+
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
259259
)
260260
response = chain.run(
261-
User=pick_best_chain.base.BasedOn("Context"),
262-
action=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
261+
User=rl_chain.BasedOn("Context"),
262+
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
263263
)
264264
# chain llm used for both basic prompt and for scoring
265265
assert response["response"] == "hey"
@@ -269,17 +269,17 @@ def test_auto_scorer_with_user_defined_llm():
269269

270270
def test_calling_chain_w_reserved_inputs_throws():
271271
llm, PROMPT = setup()
272-
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
272+
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
273273
with pytest.raises(ValueError):
274274
chain.run(
275-
User=pick_best_chain.base.BasedOn("Context"),
276-
rl_chain_selected_based_on=pick_best_chain.base.ToSelectFrom(
275+
User=rl_chain.BasedOn("Context"),
276+
rl_chain_selected_based_on=rl_chain.ToSelectFrom(
277277
["0", "1", "2"]
278278
),
279279
)
280280

281281
with pytest.raises(ValueError):
282282
chain.run(
283-
User=pick_best_chain.base.BasedOn("Context"),
284-
rl_chain_selected=pick_best_chain.base.ToSelectFrom(["0", "1", "2"]),
283+
User=rl_chain.BasedOn("Context"),
284+
rl_chain_selected=rl_chain.ToSelectFrom(["0", "1", "2"]),
285285
)

0 commit comments

Comments
 (0)