Skip to content
This repository was archived by the owner on Apr 28, 2021. It is now read-only.

Commit 4dcd932

Browse files
authored
feat: model auto load after training (#49)
1 parent d4fc8f8 commit 4dcd932

File tree

5 files changed

+21
-5
lines changed

5 files changed

+21
-5
lines changed

docker/Dockerfile.botfront

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,5 +78,8 @@ SHELL ["/bin/bash", "-o", "pipefail", "-c"]
7878

7979
# the entry point
8080
EXPOSE 5005
81-
ENTRYPOINT ["bash"]
82-
CMD ["-c", "rasa run -m models/model-$BF_PROJECT_ID.tar.gz $([ -z \"$AUTH_TOKEN\" ] && echo \"\" || echo \"--auth-token $AUTH_TOKEN\" ) --enable-api --debug"]
81+
82+
CMD rasa run \
83+
$([ -n "$MODEL_PATH" ] && echo "-m $MODEL_PATH") \
84+
$([ -n "$AUTH_TOKEN" ] && echo "--auth-token $AUTH_TOKEN" ) \
85+
--enable-api --debug

rasa/server.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1006,10 +1006,12 @@ async def train(request: Request, temporary_directory: Path) -> HTTPResponse:
10061006
"train your model.",
10071007
)
10081008

1009+
load_model_after = request.args.get("load_model_after", False)
10091010
if request.headers.get("Content-type") == YAML_CONTENT_TYPE:
10101011
training_payload = _training_payload_from_yaml(request, temporary_directory)
10111012
else:
10121013
training_payload = _training_payload_from_json(request, temporary_directory)
1014+
load_model_after = request.json.get("load_model_after", load_model_after)
10131015

10141016
try:
10151017
with app.active_training_processes.get_lock():
@@ -1023,6 +1025,15 @@ async def train(request: Request, temporary_directory: Path) -> HTTPResponse:
10231025
if training_result.model:
10241026
filename = os.path.basename(training_result.model)
10251027

1028+
if load_model_after is True:
1029+
app.agent = await _load_agent(
1030+
training_result.model,
1031+
endpoints=endpoints,
1032+
lock_store=app.agent.lock_store,
1033+
)
1034+
1035+
logger.debug(f"Successfully loaded model '{filename}'.")
1036+
10261037
return await response.file(
10271038
training_result.model,
10281039
filename=filename,
@@ -1648,7 +1659,7 @@ def _training_payload_from_json(
16481659
domain=domain_path,
16491660
config=config_paths, # bf
16501661
training_files=str(temp_dir),
1651-
output=model_output_directory,
1662+
output=os.environ.get("MODEL_PATH", DEFAULT_MODELS_PATH), # bf
16521663
force_training=request_payload.get(
16531664
"force", request.args.get("force_training", False)
16541665
),

rasa/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# this file will automatically be changed,
22
# do not add anything but the version number here!
33
__version__ = "2.3.3"
4-
__bf_patch__ = "-bf.2"
4+
__bf_patch__ = "-bf.3"

rasa_addons/importers/botfront.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ def __init__(
4747
}
4848

4949
def path_for_nlu_lang(self, lang) -> List[Text]:
50+
if len(self.nlu_config.keys()) < 2:
51+
return self._nlu_files
5052
return [x for x in self._nlu_files if f"nlu/{lang}" in x or f"nlu-{lang}" in x]
5153

5254
async def get_config(self) -> Dict:

rasa_addons/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
setup(
44
name="rasa_addons",
5-
version="2.3.3.2",
5+
version="2.3.3.3",
66
author="Botfront",
77
description="Rasa Addons - Components for Rasa and Botfront",
88
install_requires=[

0 commit comments

Comments
 (0)