Skip to content

Commit a9d774b

Browse files
ydcjeffvfdev-5trsvchn
authored
feat(app): add an option to include test file (#97)
* Merge dev to master for v0.1.0 release (#94) * Making Text classification template similar to Image Classification (#92) * Fix template sidebar * Code style * [skip ci] updated contributing and readme (#93) * [skip ci] updated contributing * [skip ci] Updated readme Co-authored-by: Taras Savchyn <[email protected]> * feat(app): add an option to include test file * fix: test_all in generate, pytest as optional dep Co-authored-by: vfdev <[email protected]> Co-authored-by: Taras Savchyn <[email protected]>
1 parent cce6dfd commit a9d774b

File tree

12 files changed

+31
-7
lines changed

12 files changed

+31
-7
lines changed

app/codegen.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,13 @@ def render_templates(self, template_name: str, config: dict):
2828
trim_blocks=True,
2929
lstrip_blocks=True,
3030
)
31-
for fname in env.list_templates(filter_func=lambda x: not x.startswith("_")):
31+
32+
def filter_func(x: str):
33+
if config["test_all"]:
34+
return not x.startswith("_")
35+
return not x.startswith("_") ^ (x == "test_all.py")
36+
37+
for fname in env.list_templates(filter_func=filter_func):
3238
code = env.get_template(fname).render(**config)
3339
self.rendered_code[fname] = code
3440
yield fname, code

templates/_base/_sidebar.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,16 @@ def ignite_loggers_options(config):
106106
st.markdown("---")
107107

108108

109+
def test_all_options(config):
110+
st.markdown("## Unit Test Options")
111+
config["test_all"] = st.checkbox(
112+
"Include a test file for the generated codes", help="Tests are implemented with pytest."
113+
)
114+
if config["test_all"]:
115+
config["test_deps"] = "pytest"
116+
st.markdown("---")
117+
118+
109119
def _setup_common_training_handlers_options(config):
110120
config["save_every_iters"] = st.number_input(
111121
"Saving iteration interval (save_every_iters)", min_value=1, value=1000

templates/gan/_sidebar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
distributed_options,
1010
ignite_handlers_options,
1111
ignite_loggers_options,
12+
test_all_options,
1213
)
1314

1415

@@ -79,5 +80,6 @@ def get_configs() -> dict:
7980
distributed_options(config)
8081
ignite_handlers_options(config)
8182
ignite_loggers_options(config)
83+
test_all_options(config)
8284

8385
return config

templates/gan/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ pytorch-ignite>=0.4.4
33
torchvision>=0.8.0
44
matplotlib>=3.3.0
55
pandas
6-
pytest
6+
{{ test_deps }}
77
{{ handler_deps }}
88
{{ logger_deps }}

templates/image_classification/_sidebar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
distributed_options,
1010
ignite_handlers_options,
1111
ignite_loggers_options,
12+
test_all_options,
1213
)
1314

1415

@@ -95,5 +96,6 @@ def get_configs():
9596
distributed_options(config)
9697
ignite_handlers_options(config)
9798
ignite_loggers_options(config)
99+
test_all_options(config)
98100

99101
return config
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
torch>=1.7.0
22
pytorch-ignite>=0.4.4
3-
pytest
3+
{{ test_deps }}
44
{{ handler_deps }}
55
{{ logger_deps }}

templates/image_classification/trainers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import torch
77
from ignite.engine import Engine
8-
from ignite.metrics import loss
98
from torch.cuda.amp import autocast
109
from torch.optim.optimizer import Optimizer
1110

templates/single/_sidebar.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
distributed_options,
1010
ignite_handlers_options,
1111
ignite_loggers_options,
12+
test_all_options,
1213
)
1314

1415

@@ -25,5 +26,6 @@ def get_configs():
2526
distributed_options(config)
2627
ignite_handlers_options(config)
2728
ignite_loggers_options(config)
29+
test_all_options(config)
2830

2931
return config

templates/single/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
setuptools
22
torch>=1.7.0
33
pytorch-ignite>=0.4.4
4-
pytest
4+
{{ test_deps }}
55
{{ handler_deps }}
66
{{ logger_deps }}

templates/text_classification/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,6 @@ torch>=1.7.0
22
pytorch-ignite>=0.4.4
33
transformers
44
datasets
5+
{{ test_deps }}
56
{{ handler_deps }}
67
{{ logger_deps }}

0 commit comments

Comments
 (0)