|
14 | 14 | import typing
|
15 | 15 | from collections.abc import Iterator
|
16 | 16 | from contextlib import contextmanager
|
| 17 | +from contextlib import nullcontext |
17 | 18 | from functools import partial
|
18 | 19 | from subprocess import CompletedProcess
|
19 | 20 | from types import FunctionType
|
20 | 21 | from types import GenericAlias
|
21 | 22 | from typing import Any
|
22 | 23 | from typing import cast
|
| 24 | +from typing import ContextManager |
23 | 25 | from typing import TYPE_CHECKING
|
24 | 26 | from typing import TypedDict
|
25 | 27 |
|
@@ -248,38 +250,70 @@ def web(self) -> requests.Session:
|
248 | 250 | return requests.Session()
|
249 | 251 |
|
250 | 252 |
|
| 253 | +class DefaultVirtualenvConfig: |
| 254 | + """ |
| 255 | + Simple class to hold registered imports. |
| 256 | + """ |
| 257 | + |
| 258 | + _instance: DefaultVirtualenvConfig | None = None |
| 259 | + venv_config: VirtualEnvConfig |
| 260 | + |
| 261 | + def __new__(cls): |
| 262 | + """ |
| 263 | + Method that instantiates a singleton class and returns it. |
| 264 | + """ |
| 265 | + if cls._instance is None: |
| 266 | + instance = super().__new__(cls) |
| 267 | + cls._instance = instance |
| 268 | + return cls._instance |
| 269 | + |
| 270 | + @classmethod |
| 271 | + def set_default_venv_config(cls, venv_config: VirtualEnvConfig) -> None: |
| 272 | + """ |
| 273 | + Register an import. |
| 274 | + """ |
| 275 | + instance = cls._instance |
| 276 | + if instance is None: |
| 277 | + instance = cls() |
| 278 | + if venv_config and "name" not in venv_config: |
| 279 | + venv_config["name"] = "default" |
| 280 | + instance.venv_config = venv_config |
| 281 | + |
| 282 | + |
251 | 283 | class RegisteredImports:
|
252 | 284 | """
|
253 | 285 | Simple class to hold registered imports.
|
254 | 286 | """
|
255 | 287 |
|
256 | 288 | _instance: RegisteredImports | None = None
|
257 |
| - _registered_imports: list[str] |
| 289 | + _registered_imports: dict[str, VirtualEnvConfig | None] |
258 | 290 |
|
259 | 291 | def __new__(cls):
|
260 | 292 | """
|
261 | 293 | Method that instantiates a singleton class and returns it.
|
262 | 294 | """
|
263 | 295 | if cls._instance is None:
|
264 | 296 | instance = super().__new__(cls)
|
265 |
| - instance._registered_imports = [] |
| 297 | + instance._registered_imports = {} |
266 | 298 | cls._instance = instance
|
267 | 299 | return cls._instance
|
268 | 300 |
|
269 | 301 | @classmethod
|
270 |
| - def register_import(cls, import_module: str) -> None: |
| 302 | + def register_import( |
| 303 | + cls, import_module: str, venv_config: VirtualEnvConfig | None = None |
| 304 | + ) -> None: |
271 | 305 | """
|
272 | 306 | Register an import.
|
273 | 307 | """
|
274 | 308 | instance = cls()
|
275 | 309 | if import_module not in instance._registered_imports:
|
276 |
| - instance._registered_imports.append(import_module) |
| 310 | + instance._registered_imports[import_module] = venv_config |
277 | 311 |
|
278 | 312 | def __iter__(self):
|
279 | 313 | """
|
280 | 314 | Return an iterator of all registered imports.
|
281 | 315 | """
|
282 |
| - return iter(self._registered_imports) |
| 316 | + return iter(self._registered_imports.items()) |
283 | 317 |
|
284 | 318 |
|
285 | 319 | class Parser:
|
@@ -370,14 +404,29 @@ def __new__(cls):
|
370 | 404 | return cls._instance
|
371 | 405 |
|
372 | 406 | def _process_registered_tool_modules(self):
|
373 |
| - for module_name in RegisteredImports(): |
374 |
| - try: |
375 |
| - importlib.import_module(module_name) |
376 |
| - except ImportError as exc: |
377 |
| - if os.environ.get("TOOLS_IGNORE_IMPORT_ERRORS", "0") == "0": |
378 |
| - self.context.warn( |
379 |
| - f"Could not import the registered tools module {module_name!r}: {exc}" |
380 |
| - ) |
| 407 | + default_venv: VirtualEnv | ContextManager[None] |
| 408 | + default_venv_config = DefaultVirtualenvConfig().venv_config |
| 409 | + if default_venv_config: |
| 410 | + default_venv = VirtualEnv(ctx=self.context, **default_venv_config) |
| 411 | + else: |
| 412 | + default_venv = nullcontext() |
| 413 | + with default_venv: |
| 414 | + for module_name, venv_config in RegisteredImports(): |
| 415 | + venv: VirtualEnv | ContextManager[None] |
| 416 | + if venv_config: |
| 417 | + if "name" not in venv_config: |
| 418 | + venv_config["name"] = module_name |
| 419 | + venv = VirtualEnv(ctx=self.context, **venv_config) |
| 420 | + else: |
| 421 | + venv = nullcontext() |
| 422 | + with venv: |
| 423 | + try: |
| 424 | + importlib.import_module(module_name) |
| 425 | + except ImportError as exc: |
| 426 | + if os.environ.get("TOOLS_IGNORE_IMPORT_ERRORS", "0") == "0": |
| 427 | + self.context.warn( |
| 428 | + f"Could not import the registered tools module {module_name!r}: {exc}" |
| 429 | + ) |
381 | 430 |
|
382 | 431 | def parse_args(self):
|
383 | 432 | """
|
@@ -471,6 +520,8 @@ def __init__(self, name, help, description=None, parent=None, venv_config=None):
|
471 | 520 | GroupReference.add_command(tuple(parent + [name]), self)
|
472 | 521 | parent = GroupReference()[tuple(parent)]
|
473 | 522 |
|
| 523 | + if venv_config and "name" not in venv_config: |
| 524 | + venv_config["name"] = self.name |
474 | 525 | self.venv_config = venv_config or {}
|
475 | 526 | self.parser = parent.subparsers.add_parser(
|
476 | 527 | name.replace("_", "-"),
|
@@ -634,22 +685,22 @@ def __call__(self, func, options, venv_config: VirtualEnvConfig | None = None):
|
634 | 685 | kwargs[name] = getattr(options, name)
|
635 | 686 |
|
636 | 687 | bound = signature.bind_partial(*args, **kwargs)
|
637 |
| - venv = None |
| 688 | + venv: VirtualEnv | ContextManager[None] |
638 | 689 | if venv_config:
|
639 |
| - venv_name = getattr(options, f"{self.name}_command") |
640 |
| - venv = VirtualEnv(name=f"{self.name}.{venv_name}", ctx=self.context, **venv_config) |
| 690 | + if "name" not in venv_config: |
| 691 | + venv_config["name"] = getattr(options, f"{self.name}_command") |
| 692 | + venv = VirtualEnv(ctx=self.context, **venv_config) |
641 | 693 | elif self.venv_config:
|
642 |
| - venv = VirtualEnv(name=self.name, ctx=self.context, **self.venv_config) |
643 |
| - if venv: |
644 |
| - with venv: |
645 |
| - previous_venv = self.context.venv |
646 |
| - try: |
647 |
| - self.context.venv = venv |
648 |
| - func(self.context, *bound.args, **bound.kwargs) |
649 |
| - finally: |
650 |
| - self.context.venv = previous_venv |
| 694 | + venv = VirtualEnv(ctx=self.context, **self.venv_config) |
651 | 695 | else:
|
652 |
| - func(self.context, *bound.args, **bound.kwargs) |
| 696 | + venv = nullcontext() |
| 697 | + with venv: |
| 698 | + previous_venv = self.context.venv |
| 699 | + try: |
| 700 | + self.context.venv = venv |
| 701 | + func(self.context, *bound.args, **bound.kwargs) |
| 702 | + finally: |
| 703 | + self.context.venv = previous_venv |
653 | 704 |
|
654 | 705 |
|
655 | 706 | def command_group(
|
|
0 commit comments