Skip to content

Commit a6f9dcc

Browse files
committed
rename rl_chain_base to base and update paths and imports
1 parent b422dc0 commit a6f9dcc

File tree

6 files changed

+66
-64
lines changed

6 files changed

+66
-64
lines changed

libs/langchain/langchain/chains/rl_chain/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from langchain.chains.rl_chain.pick_best_chain import PickBest
2-
from langchain.chains.rl_chain.rl_chain_base import (
2+
from langchain.chains.rl_chain.base import (
33
Embed,
44
BasedOn,
55
ToSelectFrom,

libs/langchain/langchain/chains/rl_chain/pick_best_chain.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from __future__ import annotations
22

3-
import langchain.chains.rl_chain.rl_chain_base as base
3+
import langchain.chains.rl_chain.base as base
44

55
from langchain.callbacks.manager import CallbackManagerForChainRun
66
from langchain.chains.base import Chain

libs/langchain/tests/unit_tests/chains/rl_chain/test_pick_best_chain_call.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
import langchain.chains.rl_chain as rl_chain
1+
import langchain.chains.rl_chain.pick_best_chain as pick_best_chain
2+
import langchain.chains.rl_chain.base as rl_chain
23
from test_utils import MockEncoder
34
import pytest
45
from langchain.prompts.prompt import PromptTemplate
@@ -17,7 +18,7 @@ def setup():
1718

1819
def test_multiple_ToSelectFrom_throws():
1920
llm, PROMPT = setup()
20-
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
21+
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
2122
actions = ["0", "1", "2"]
2223
with pytest.raises(ValueError):
2324
chain.run(
@@ -29,15 +30,15 @@ def test_multiple_ToSelectFrom_throws():
2930

3031
def test_missing_basedOn_from_throws():
3132
llm, PROMPT = setup()
32-
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
33+
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
3334
actions = ["0", "1", "2"]
3435
with pytest.raises(ValueError):
3536
chain.run(action=rl_chain.ToSelectFrom(actions))
3637

3738

3839
def test_ToSelectFrom_not_a_list_throws():
3940
llm, PROMPT = setup()
40-
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
41+
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
4142
actions = {"actions": ["0", "1", "2"]}
4243
with pytest.raises(ValueError):
4344
chain.run(
@@ -50,7 +51,7 @@ def test_update_with_delayed_score_with_auto_validator_throws():
5051
llm, PROMPT = setup()
5152
# this LLM returns a number so that the auto validator will return that
5253
auto_val_llm = FakeListChatModel(responses=["3"])
53-
chain = rl_chain.PickBest.from_llm(
54+
chain = pick_best_chain.PickBest.from_llm(
5455
llm=llm,
5556
prompt=PROMPT,
5657
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
@@ -71,7 +72,7 @@ def test_update_with_delayed_score_force():
7172
llm, PROMPT = setup()
7273
# this LLM returns a number so that the auto validator will return that
7374
auto_val_llm = FakeListChatModel(responses=["3"])
74-
chain = rl_chain.PickBest.from_llm(
75+
chain = pick_best_chain.PickBest.from_llm(
7576
llm=llm,
7677
prompt=PROMPT,
7778
selection_scorer=rl_chain.AutoSelectionScorer(llm=auto_val_llm),
@@ -92,7 +93,7 @@ def test_update_with_delayed_score_force():
9293

9394
def test_update_with_delayed_score():
9495
llm, PROMPT = setup()
95-
chain = rl_chain.PickBest.from_llm(
96+
chain = pick_best_chain.PickBest.from_llm(
9697
llm=llm, prompt=PROMPT, selection_scorer=None
9798
)
9899
actions = ["0", "1", "2"]
@@ -115,7 +116,7 @@ def score_response(self, inputs, llm_response: str) -> float:
115116
score = 200
116117
return score
117118

118-
chain = rl_chain.PickBest.from_llm(
119+
chain = pick_best_chain.PickBest.from_llm(
119120
llm=llm, prompt=PROMPT, selection_scorer=CustomSelectionScorer()
120121
)
121122
actions = ["0", "1", "2"]
@@ -130,8 +131,8 @@ def score_response(self, inputs, llm_response: str) -> float:
130131

131132
def test_default_embeddings():
132133
llm, PROMPT = setup()
133-
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
134-
chain = rl_chain.PickBest.from_llm(
134+
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
135+
chain = pick_best_chain.PickBest.from_llm(
135136
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
136137
)
137138

@@ -163,8 +164,8 @@ def test_default_embeddings():
163164

164165
def test_default_embeddings_off():
165166
llm, PROMPT = setup()
166-
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
167-
chain = rl_chain.PickBest.from_llm(
167+
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
168+
chain = pick_best_chain.PickBest.from_llm(
168169
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder, auto_embed=False
169170
)
170171

@@ -188,8 +189,8 @@ def test_default_embeddings_off():
188189

189190
def test_default_embeddings_mixed_w_explicit_user_embeddings():
190191
llm, PROMPT = setup()
191-
feature_embedder = rl_chain.PickBestFeatureEmbedder(model=MockEncoder())
192-
chain = rl_chain.PickBest.from_llm(
192+
feature_embedder = pick_best_chain.PickBestFeatureEmbedder(model=MockEncoder())
193+
chain = pick_best_chain.PickBest.from_llm(
193194
llm=llm, prompt=PROMPT, feature_embedder=feature_embedder
194195
)
195196

@@ -223,7 +224,7 @@ def test_default_embeddings_mixed_w_explicit_user_embeddings():
223224
def test_default_no_scorer_specified():
224225
_, PROMPT = setup()
225226
chain_llm = FakeListChatModel(responses=[100])
226-
chain = rl_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
227+
chain = pick_best_chain.PickBest.from_llm(llm=chain_llm, prompt=PROMPT)
227228
response = chain.run(
228229
User=rl_chain.BasedOn("Context"),
229230
action=rl_chain.ToSelectFrom(["0", "1", "2"]),
@@ -236,7 +237,7 @@ def test_default_no_scorer_specified():
236237

237238
def test_explicitly_no_scorer():
238239
llm, PROMPT = setup()
239-
chain = rl_chain.PickBest.from_llm(
240+
chain = pick_best_chain.PickBest.from_llm(
240241
llm=llm, prompt=PROMPT, selection_scorer=None
241242
)
242243
response = chain.run(
@@ -252,7 +253,7 @@ def test_explicitly_no_scorer():
252253
def test_auto_scorer_with_user_defined_llm():
253254
llm, PROMPT = setup()
254255
scorer_llm = FakeListChatModel(responses=[300])
255-
chain = rl_chain.PickBest.from_llm(
256+
chain = pick_best_chain.PickBest.from_llm(
256257
llm=llm,
257258
prompt=PROMPT,
258259
selection_scorer=rl_chain.AutoSelectionScorer(llm=scorer_llm),
@@ -269,7 +270,7 @@ def test_auto_scorer_with_user_defined_llm():
269270

270271
def test_calling_chain_w_reserved_inputs_throws():
271272
llm, PROMPT = setup()
272-
chain = rl_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
273+
chain = pick_best_chain.PickBest.from_llm(llm=llm, prompt=PROMPT)
273274
with pytest.raises(ValueError):
274275
chain.run(
275276
User=rl_chain.BasedOn("Context"),

0 commit comments

Comments (0)