Skip to content
This repository was archived by the owner on Sep 13, 2023. It is now read-only.

Commit 0939e5d

Browse files
authored
Fix declaring classes with abc fields with defaults (#625)
1 parent daf022e commit 0939e5d

File tree

3 files changed

+105
-18
lines changed

3 files changed

+105
-18
lines changed

mlem/cli/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def _add_examples(
3333
allow_none=False,
3434
default=None,
3535
root_cls=root_cls,
36+
force_not_set=False,
3637
),
3738
root_cls=root_cls,
3839
parent_help=f"Element of {field.path}",

mlem/cli/utils.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -198,26 +198,34 @@ def _get_type_name_alias(type_):
198198
return type_.__name__ if type_ is not None else "any"
199199

200200

201-
def anything(type_):
201+
def anything(type_, allow_none: bool):
202202
"""Creates special type that is named as original type or collection type
203203
It returns original object on creation and is needed for nice typename in cli option help
204204
"""
205-
return type(
206-
_get_type_name_alias(type_), (), {"__new__": lambda cls, value: value}
207-
)
205+
206+
def new(cls, value): # pylint: disable=unused-argument
207+
"""Just return the value"""
208+
if allow_none and value == "None":
209+
return None
210+
return value
211+
212+
return type(_get_type_name_alias(type_), (), {"__new__": new})
208213

209214

210215
def optional(type_):
211216
"""Creates special type that is named as original type or collection type
212217
It allows use string `None` to indicate None value"""
218+
219+
def new(cls, value): # pylint: disable=unused-argument
220+
"""Check if value is string None"""
221+
if value == "None":
222+
return None
223+
return type_(value)
224+
213225
return type(
214226
_get_type_name_alias(type_),
215227
(),
216-
{
217-
"__new__": lambda cls, value: None
218-
if value == "None"
219-
else type_(value)
220-
},
228+
{"__new__": new},
221229
)
222230

223231

@@ -231,6 +239,7 @@ def parse_type_field(
231239
allow_none: bool,
232240
default: Any,
233241
root_cls: Type[BaseModel],
242+
force_not_set: bool,
234243
) -> Iterator[CliTypeField]:
235244
"""Recursively creates CliTypeFields from field description"""
236245
if is_list or is_mapping:
@@ -278,7 +287,7 @@ def parse_type_field(
278287
allow_none=allow_none,
279288
path=path,
280289
type_=type_,
281-
default=default,
290+
default=default if not force_not_set else NOT_SET,
282291
help=help_,
283292
is_list=is_list,
284293
is_mapping=is_mapping,
@@ -333,22 +342,24 @@ def iterate_type_fields(
333342
display_as_type(field_type), __root__=(field_type, ...)
334343
)
335344
if field_type is Any:
336-
field_type = anything(field_type)
345+
field_type = anything(field_type, field.allow_none)
337346

338347
if not isinstance(field_type, type):
339348
# skip too complicated stuff
340349
continue
341350

351+
required = not force_not_req and bool(field.required)
342352
yield from parse_type_field(
343353
path=fullname,
344354
type_=field_type,
345355
help_=get_field_help(cls, name),
346356
is_list=field.shape in LIST_LIKE_SHAPES,
347357
is_mapping=field.shape in MAPPING_LIKE_SHAPES,
348-
required=not force_not_req and bool(field.required),
358+
required=required,
349359
allow_none=field.allow_none,
350-
default=field.default,
360+
default=field.default if required else NOT_SET,
351361
root_cls=root_cls,
362+
force_not_set=force_not_req,
352363
)
353364

354365

@@ -381,11 +392,18 @@ def _options_from_model(
381392
continue
382393
if issubclass(field.type_, MlemABC) and field.type_.__is_root__:
383394
yield from _options_from_mlem_abc(
384-
ctx, field, path, force_not_set=force_not_set
395+
ctx,
396+
field,
397+
path,
398+
force_not_set=force_not_set or field.default == NOT_SET,
385399
)
386400
continue
387401

388-
yield _option_from_field(field, path, force_not_set=force_not_set)
402+
yield _option_from_field(
403+
field,
404+
path,
405+
force_not_set=force_not_set or field.default == NOT_SET,
406+
)
389407

390408

391409
def _options_from_mlem_abc(
@@ -506,12 +524,12 @@ def _option_from_field(
506524
"""Create cli option from field descriptor"""
507525
type_ = override_type or field.type_
508526
if force_not_set:
509-
type_ = anything(type_)
527+
type_ = anything(type_, field.allow_none)
510528
elif field.allow_none:
511529
type_ = optional(type_)
512530
option = SetViaFileTyperOption(
513531
param_decls=[f"--{path}", path.replace(".", "_")],
514-
type=type_ if not force_not_set else anything(type_),
532+
type=type_ if not force_not_set else anything(type_, field.allow_none),
515533
required=field.required and not force_not_set,
516534
default=field.default
517535
if not field.is_list and not field.is_mapping and not force_not_set
@@ -531,7 +549,10 @@ def generator(ctx: CallContext):
531549
cls = load_impl_ext(mlem_abc.abs_name, type_name=type_name)
532550
except ImportError:
533551
return
534-
yield from _options_from_model(cls, ctx)
552+
yield from _options_from_model(
553+
cls,
554+
ctx,
555+
)
535556

536557
return generator
537558

tests/cli/test_declare.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from mlem.cli.declare import create_declare_mlem_object_subcommand, declare
88
from mlem.contrib.docker import DockerDirBuilder
9+
from mlem.contrib.docker.base import DockerRegistry, RemoteRegistry
910
from mlem.contrib.docker.context import DockerBuildArgs
1011
from mlem.contrib.fastapi import FastAPIServer
1112
from mlem.contrib.heroku.meta import HerokuEnv
@@ -404,6 +405,70 @@ class RootListNested(_MockBuilder):
404405
)
405406

406407

408+
class MockOptionalFieldWithNonOptionalSubfield(_MockBuilder):
409+
"""mock"""
410+
411+
f: Optional[SimpleValue] = None
412+
413+
414+
all_test_params.append(
415+
pytest.param(
416+
MockOptionalFieldWithNonOptionalSubfield(),
417+
"",
418+
id="non_optional_subfield_empty",
419+
)
420+
)
421+
all_test_params.append(
422+
pytest.param(
423+
MockOptionalFieldWithNonOptionalSubfield(f=SimpleValue(value="a")),
424+
"--f.value a",
425+
id="non_optional_subfield_full",
426+
)
427+
)
428+
429+
430+
class ThreeValues(BaseModel):
431+
value: str
432+
with_def: str = "value"
433+
opt: Optional[str] = None
434+
with_def_model: SimpleValue = SimpleValue(value="value")
435+
with_def_abc: DockerRegistry = DockerRegistry()
436+
437+
438+
class MockOptionalFieldWithOptionalAndNonOptionalSubfield(_MockBuilder):
439+
"""mock"""
440+
441+
f: Optional[ThreeValues] = None
442+
443+
444+
all_test_params.append(
445+
pytest.param(
446+
MockOptionalFieldWithOptionalAndNonOptionalSubfield(),
447+
"",
448+
id="optional_and_non_optional_subfield_empty",
449+
)
450+
)
451+
all_test_params.append(
452+
pytest.param(
453+
MockOptionalFieldWithOptionalAndNonOptionalSubfield(
454+
f=ThreeValues(value="a")
455+
),
456+
"--f.value a",
457+
id="optional_and_non_optional_subfield_full",
458+
)
459+
)
460+
461+
all_test_params.append(
462+
pytest.param(
463+
MockOptionalFieldWithOptionalAndNonOptionalSubfield(
464+
f=ThreeValues(value="a", with_def_abc=RemoteRegistry(host="aaa"))
465+
),
466+
"--f.value a --f.with_def_abc remote --f.with_def_abc.host aaa",
467+
id="optional_and_non_optional_subfield_full_abc",
468+
)
469+
)
470+
471+
407472
@lru_cache()
408473
def _declare_builder_command(type_: str):
409474
create_declare_mlem_object_subcommand(

0 commit comments

Comments
 (0)