Commit e985ae2

ibro45 and claude committed
feat: Add _imports_ key and rename globals parameter to imports
Add `_imports_` key support in YAML configs to declare imports available to all expressions:

```yaml
_imports_:
  torch: torch
  np: numpy
  Path: pathlib.Path

device: "$torch.device('cuda')"
```

Also rename the `Config` parameter from `globals` to `imports` for consistency.

BREAKING CHANGE: `Config(globals=...)` is now `Config(imports=...)`

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 41431d6 commit e985ae2

5 files changed: +230, -36 lines

docs/user-guide/advanced.md

Lines changed: 40 additions & 9 deletions

````diff
@@ -117,8 +117,8 @@ Sparkwheel recognizes these special keys in configuration:
 
 - `_target_`: Class or function path to instantiate (e.g., `"torch.nn.Linear"`)
 - `_disabled_`: Skip instantiation if `true` (removed from parent). See [Instantiation](instantiation.md#_disabled_-skip-instantiation) for details.
-- `_requires_`: List of dependencies to evaluate/instantiate first
 - `_mode_`: Operating mode for instantiation (see below)
+- `_imports_`: Declare imports available to all expressions (see [Imports](#imports-for-expressions) below)
 
 ### `_mode_` - Instantiation Modes
 
@@ -289,25 +289,56 @@ except ConfigKeyError as e:
 
 Color output is auto-detected and respects `NO_COLOR` environment variable.
 
-## Globals for Expressions
+## Imports for Expressions
 
-Pre-import modules for use in expressions:
+Make modules available to all expressions. There are two ways to do this:
+
+### Method 1: `_imports_` Key in YAML
+
+Declare imports directly in your config file:
+
+```yaml
+# config.yaml
+_imports_:
+  torch: torch
+  np: numpy
+  Path: pathlib.Path
+
+# Now use them in expressions
+device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
+data: "$np.array([1, 2, 3])"
+save_path: "$Path('/data/models')"
+```
+
+The `_imports_` key is removed from the config after processing—it won't appear in your resolved config.
+
+### Method 2: `imports` Parameter in Python
+
+Pass imports when creating the Config:
 
 ```python
 from sparkwheel import Config
 
-# Pre-import torch for all expressions
-config = Config(globals={"torch": "torch", "np": "numpy"})
+# Pre-import modules for all expressions
+config = Config(imports={"torch": "torch", "np": "numpy"})
 config.update("config.yaml")
 
 # Now expressions can use torch and np without importing
 ```
 
-Example config:
+### Combining Both Methods
 
-```yaml
-device: "$torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
-data: "$np.array([1, 2, 3])"
+You can use both approaches together—they merge:
+
+```python
+from collections import Counter
+
+config = Config(imports={"Counter": Counter})
+config.update({
+    "_imports_": {"json": "json"},
+    "data": '$json.dumps({"a": 1})',
+    "counts": "$Counter([1, 1, 2])"
+})
 ```
 
 ## Type Hints
````
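The docs above describe `$`-prefixed expressions evaluated against a dict of pre-imported modules. A minimal, hypothetical sketch of that mechanism (this is not sparkwheel's actual evaluator; stdlib modules stand in for torch/numpy so it is self-contained):

```python
import importlib

def evaluate(expr, imports):
    # Evaluate "$"-prefixed strings against pre-imported modules,
    # mimicking the documented behavior; plain values pass through.
    if isinstance(expr, str) and expr.startswith("$"):
        return eval(expr[1:], {"__builtins__": {}}, dict(imports))
    return expr

# stdlib stand-ins for the docs' torch/np/Path imports
imports = {
    "json": importlib.import_module("json"),
    "Path": importlib.import_module("pathlib").Path,
}
evaluate('$json.dumps({"a": 1})', imports)   # JSON string
evaluate("$Path('/data/models')", imports)   # Path object
```

Emptying `__builtins__` in the sketch limits what an expression can reach; only the names in the imports dict are visible.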

docs/user-guide/basics.md

Lines changed: 2 additions & 2 deletions

````diff
@@ -409,10 +409,10 @@ Sparkwheel reserves certain keys with special meaning:
 
 - `_target_`: Specifies a class to instantiate
 - `_disabled_`: Skip instantiation if true
-- `_requires_`: Dependencies that must be resolved first
 - `_mode_`: Instantiation mode (default, callable, debug)
+- `_imports_`: Declare imports available to all expressions
 
-These are covered in detail in [Instantiation Guide](instantiation.md).
+These are covered in detail in [Instantiation Guide](instantiation.md) and [Advanced Features](advanced.md).
 
 ## Common Patterns
 
````
docs/user-guide/instantiation.md

Lines changed: 0 additions & 1 deletion

````diff
@@ -224,7 +224,6 @@ augmentation:
 - `_target_`: Class or function path to instantiate (required)
 - `_args_`: List of positional arguments to pass
 - `_disabled_`: Skip instantiation if `true` (removed from parent)
-- `_requires_`: Dependencies to resolve first
 - `_mode_`: Instantiation mode (`"default"`, `"callable"`, or `"debug"`)
 
 For complete details, see the [Advanced Features](advanced.md) and [API Reference](../reference/).
````
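The special keys listed in this diff can be illustrated with a tiny, hypothetical instantiator (not sparkwheel's implementation) that resolves `_target_` via importlib and honors `_args_` and `_disabled_`:

```python
from importlib import import_module

def instantiate(spec):
    # Hypothetical sketch: resolve _target_, honor _disabled_ and _args_.
    spec = dict(spec)
    if spec.pop("_disabled_", False):
        return None  # _disabled_: skip instantiation
    module_path, _, attr = spec.pop("_target_").rpartition(".")
    target = getattr(import_module(module_path), attr)
    args = spec.pop("_args_", [])
    return target(*args, **spec)  # remaining keys become keyword arguments

counter = instantiate({"_target_": "collections.Counter", "_args_": [[1, 1, 2]]})
```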

src/sparkwheel/config.py

Lines changed: 70 additions & 16 deletions

````diff
@@ -148,7 +148,7 @@ class Config:
     ```
 
     Args:
-        globals: Pre-imported packages for expressions (e.g., {"torch": "torch"})
+        imports: Pre-imported packages for expressions (e.g., {"torch": "torch"})
         schema: Dataclass schema for continuous validation
         coerce: Auto-convert compatible types (default: True)
         strict: Reject fields not in schema (default: True)
@@ -159,7 +159,7 @@ def __init__(
         self,
         data: dict[str, Any] | None = None,  # Internal/testing use only
         *,  # Rest are keyword-only
-        globals: dict[str, Any] | None = None,
+        imports: dict[str, Any] | None = None,
         schema: type | None = None,
         coerce: bool = True,
         strict: bool = True,
@@ -171,7 +171,7 @@ def __init__(
 
         Args:
             data: Initial data (internal/testing use only, not validated)
-            globals: Pre-imported packages for expression evaluation
+            imports: Pre-imported packages for expression evaluation
             schema: Dataclass schema for continuous validation
             coerce: Auto-convert compatible types
             strict: Reject fields not in schema
@@ -196,14 +196,14 @@ def __init__(
         self._strict: bool = strict
         self._allow_missing: bool = allow_missing
 
-        # Process globals (import string module paths)
-        self._globals: dict[str, Any] = {}
-        if isinstance(globals, dict):
-            for k, v in globals.items():
-                self._globals[k] = optional_import(v)[0] if isinstance(v, str) else v
+        # Process imports (import string module paths)
+        self._imports: dict[str, Any] = {}
+        if isinstance(imports, dict):
+            for k, v in imports.items():
+                self._imports[k] = optional_import(v)[0] if isinstance(v, str) else v
 
         self._loader = Loader()
-        self._preprocessor = Preprocessor(self._loader, self._globals)
+        self._preprocessor = Preprocessor(self._loader, self._imports)
 
     def get(self, id: str = "", default: Any = None) -> Any:
         """Get raw config value (unresolved).
@@ -683,6 +683,10 @@ def _parse(self, reset: bool = True) -> None:
         if reset:
             self._resolver.reset()
 
+        # Process _imports_ key if present in config data
+        # This allows YAML-based imports that become available to all expressions
+        self._process_imports_key()
+
         # Phase 2: Expand local raw references (%key) now that all composition is complete
         # CLI overrides have been applied, so local refs will see final values
         self._data = self._preprocessor.process_raw_refs(
@@ -693,14 +697,60 @@ def _parse(self, reset: bool = True) -> None:
         self._data = self._preprocessor.process(self._data, self._data, id="")
 
         # Stage 2: Parse config tree to create Items
-        parser = Parser(globals=self._globals, metadata=self._locations)
+        parser = Parser(globals=self._imports, metadata=self._locations)
         items = parser.parse(self._data)
 
         # Stage 3: Add items to resolver
         self._resolver.add_items(items)
 
         self._is_parsed = True
 
+    def _process_imports_key(self) -> None:
+        """Process _imports_ key from config data.
+
+        The _imports_ key allows declaring imports directly in YAML:
+
+        ```yaml
+        _imports_:
+          torch: torch
+          np: numpy
+          Path: pathlib.Path
+
+        model:
+          device: "$torch.device('cuda')"
+        ```
+
+        These imports become available to all expressions in the config.
+        The _imports_ key is removed from the data after processing.
+        """
+        imports_key = "_imports_"
+        if imports_key not in self._data:
+            return
+
+        imports_config = self._data.pop(imports_key)
+        if not isinstance(imports_config, dict):
+            return
+
+        # Process each import
+        for name, module_path in imports_config.items():
+            if isinstance(module_path, str):
+                # Handle dotted paths like "pathlib.Path" or "collections.Counter"
+                # Split into module and attribute if needed
+                if "." in module_path:
+                    parts = module_path.rsplit(".", 1)
+                    # First try as a module (e.g., "os.path")
+                    module_obj, success = optional_import(module_path)
+                    if not success:
+                        # Try as module.attribute (e.g., "pathlib.Path")
+                        module_obj, success = optional_import(parts[0], name=parts[1])
+                    self._imports[name] = module_obj
+                else:
+                    # Simple module name like "json"
+                    self._imports[name] = optional_import(module_path)[0]
+            else:
+                # Already a module or callable
+                self._imports[name] = module_path
+
     def _get_by_id(self, id: str) -> Any:
         """Get config value by ID path.
 
@@ -789,7 +839,8 @@ def parse_overrides(args: list[str]) -> dict[str, Any]:
     """Parse CLI argument overrides with automatic type inference.
 
     Supports only key=value syntax with operator prefixes.
-    Types are automatically inferred using ast.literal_eval().
+    Values are parsed using YAML syntax (via ``yaml.safe_load``), ensuring
+    CLI overrides behave identically to values in YAML config files.
 
     Args:
         args: List of argument strings to parse (e.g., from argparse)
@@ -805,21 +856,24 @@ def parse_overrides(args: list[str]) -> dict[str, Any]:
 
     Examples:
         >>> # Basic overrides (compose/merge)
-        >>> parse_overrides(["model::lr=0.001", "debug=True"])
+        >>> parse_overrides(["model::lr=0.001", "debug=true"])
         {"model::lr": 0.001, "debug": True}
 
         >>> # With operators
-        >>> parse_overrides(["=model={'_target_': 'ResNet'}", "~old_param"])
+        >>> parse_overrides(["=model={_target_: ResNet}", "~old_param"])
         {"=model": {'_target_': 'ResNet'}, "~old_param": None}
 
         >>> # Nested paths with operators
         >>> parse_overrides(["=optimizer::lr=0.01", "~model::old_param"])
         {"=optimizer::lr": 0.01, "~model::old_param": None}
 
     Note:
-        The '=' character serves dual purpose:
-        - In 'key=value' → assignment operator (CLI syntax)
-        - In '=key=value' → replace operator prefix (config operator)
+        - The '=' character serves dual purpose:
+          - In 'key=value' → assignment operator (CLI syntax)
+          - In '=key=value' → replace operator prefix (config operator)
+        - Values use YAML syntax: ``true``/``false``, ``yes``/``no``, ``on``/``off``
+          for booleans, ``null`` or ``~`` for None, ``{key: value}`` for dicts.
+        - Python's ``None`` is parsed as the string ``"None"`` (use ``null`` instead).
     """
     import yaml
 
````
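The dotted-path handling in `_process_imports_key` above tries the whole string as a module first, then falls back to `module.attribute`. A standalone sketch of that strategy with plain `importlib` (sparkwheel itself goes through its `optional_import` helper):

```python
import importlib

def resolve_import(path):
    # Try the whole dotted string as a module (e.g., "os.path");
    # on failure, treat the last segment as an attribute (e.g., "pathlib.Path").
    try:
        return importlib.import_module(path)
    except ImportError:
        module, _, attr = path.rpartition(".")
        return getattr(importlib.import_module(module), attr)
```

`resolve_import("os.path")` yields the submodule, while `resolve_import("pathlib.Path")` yields the class.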
tests/test_config.py

Lines changed: 118 additions & 8 deletions

````diff
@@ -134,17 +134,127 @@ def test_init_with_none(self):
         assert isinstance(parser._data, dict)
         assert parser._data == {}
 
-    def test_init_with_globals_dict(self):
-        """Test Config init with globals dict."""
-        parser = Config({}, globals={"pd": "pandas"})
-        assert "pd" in parser._globals
+    def test_init_with_imports_dict(self):
+        """Test Config init with imports dict."""
+        parser = Config({}, imports={"pd": "pandas"})
+        assert "pd" in parser._imports
 
-    def test_init_with_globals_callable(self):
-        """Test Config init with globals containing callables."""
+    def test_init_with_imports_callable(self):
+        """Test Config init with imports containing callables."""
         from collections import Counter
 
-        parser = Config({}, globals={"Counter": Counter})
-        assert parser._globals["Counter"] is Counter
+        parser = Config({}, imports={"Counter": Counter})
+        assert parser._imports["Counter"] is Counter
+
+
+class TestConfigImports:
+    """Test _imports_ key handling."""
+
+    def test_imports_key_basic(self):
+        """Test _imports_ key makes modules available to expressions."""
+        config = Config().update(
+            {
+                "_imports_": {"json": "json"},
+                "data": '$json.dumps({"a": 1})',
+            }
+        )
+        result = config.resolve("data")
+        assert result == '{"a": 1}'
+
+    def test_imports_key_multiple_modules(self):
+        """Test _imports_ with multiple modules."""
+        config = Config().update(
+            {
+                "_imports_": {
+                    "os": "os",
+                    "Path": "pathlib.Path",
+                },
+                "sep": "$os.sep",
+                "path_type": "$Path",
+            }
+        )
+        import os
+        from pathlib import Path
+
+        assert config.resolve("sep") == os.sep
+        assert config.resolve("path_type") is Path
+
+    def test_imports_key_removed_from_data(self):
+        """Test _imports_ key is removed from config data after processing."""
+        config = Config().update(
+            {
+                "_imports_": {"json": "json"},
+                "data": '$json.dumps({"a": 1})',
+            }
+        )
+        config.resolve()  # Trigger parsing
+        assert "_imports_" not in config._data
+
+    def test_imports_key_combined_with_imports_parameter(self):
+        """Test _imports_ key works with imports parameter."""
+        from collections import Counter
+
+        config = Config(imports={"Counter": Counter}).update(
+            {
+                "_imports_": {"json": "json"},
+                "counter": "$Counter([1, 1, 2])",
+                "data": '$json.dumps({"a": 1})',
+            }
+        )
+        assert config.resolve("counter") == Counter([1, 1, 2])
+        assert config.resolve("data") == '{"a": 1}'
+
+    def test_imports_key_invalid_value_ignored(self):
+        """Test _imports_ with invalid value is ignored gracefully."""
+        config = Config().update(
+            {
+                "_imports_": "not a dict",
+                "value": 42,
+            }
+        )
+        result = config.resolve("value")
+        assert result == 42
+
+    def test_imports_key_with_dotted_class_path(self):
+        """Test _imports_ with dotted path to a class (e.g., pathlib.Path)."""
+        from collections import Counter
+
+        config = Config().update(
+            {
+                "_imports_": {"Counter": "collections.Counter"},
+                "counts": "$Counter([1, 1, 2, 2, 2])",
+            }
+        )
+        result = config.resolve("counts")
+        assert result == Counter([1, 1, 2, 2, 2])
+
+    def test_imports_key_with_dotted_module_path(self):
+        """Test _imports_ with dotted path to a submodule (e.g., os.path)."""
+        import os.path
+
+        config = Config().update(
+            {
+                "_imports_": {"ospath": "os.path"},
+                "sep": "$ospath.sep",
+            }
+        )
+        result = config.resolve("sep")
+        assert result == os.path.sep
+
+    def test_imports_key_with_non_string_value(self):
+        """Test _imports_ with non-string value (already imported module)."""
+        import json
+
+        # Pass the module directly via imports parameter, then use _imports_ with non-string
+        # Note: Can't put module in _imports_ dict in update() due to deepcopy,
+        # so we test via direct _data manipulation before parse
+        config = Config()
+        config._data = {
+            "_imports_": {"my_json": json},
+            "data": '$my_json.dumps({"a": 1})',
+        }
+        result = config.resolve("data")
+        assert result == '{"a": 1}'
 
 
 class TestConfigReferences:
````
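The updated `parse_overrides` docstring in the config.py diff says values are parsed with `yaml.safe_load`. A minimal sketch of just the value-parsing step (assumes PyYAML, which the function itself imports):

```python
import yaml  # PyYAML, imported by parse_overrides in the diff above

def parse_value(text):
    # YAML typing: true/false booleans, null for None, {k: v} inline dicts.
    # Python-style "None" stays a string, as the docstring's note warns.
    return yaml.safe_load(text)
```

So `true` parses to `True`, `null` to `None`, and `None` to the string `"None"`.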
