Skip to content

Commit 38d5949

Browse files
committed
misc tidy
1 parent 64b7ca0 commit 38d5949

File tree

2 files changed

+40
-31
lines changed

2 files changed

+40
-31
lines changed

shtab/main.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,23 @@
11
import argparse
2-
import contextlib
32
import logging
43
import os
54
import sys
5+
from contextlib import contextmanager
66
from importlib import import_module
77
from pathlib import Path
8-
from typing import Generator
9-
from typing import Optional as Opt
10-
from typing import TextIO
118

129
from . import SUPPORTED_SHELLS, __version__, add_argument_to, complete
1310

1411
log = logging.getLogger(__name__)
1512

1613

17-
@contextlib.contextmanager
18-
def extract_stdout(output: Opt[Path]) -> Generator[TextIO, None, None]:
19-
if output is None:
20-
yield sys.stdout
21-
else:
22-
with output.open("w") as stdout:
23-
yield stdout
24-
25-
2614
def get_main_parser():
2715
parser = argparse.ArgumentParser(prog="shtab")
2816
parser.add_argument("parser", help="importable parser (or function returning parser)")
2917
parser.add_argument("--version", action="version", version="%(prog)s " + __version__)
3018
parser.add_argument("-s", "--shell", default=SUPPORTED_SHELLS[0], choices=SUPPORTED_SHELLS)
31-
parser.add_argument("-o", "--output", help="write output to file instead of stdout", type=Path)
19+
parser.add_argument("-o", "--output", default='-', help="output file (- for stdout)",
20+
type=Path)
3221
parser.add_argument("--prefix", help="prepended to generated functions to avoid clashes")
3322
parser.add_argument("--preamble", help="prepended to generated script")
3423
parser.add_argument("--prog", help="custom program name (overrides `parser.prog`)")
@@ -67,7 +56,16 @@ def main(argv=None):
6756
other_parser = other_parser()
6857
if args.prog:
6958
other_parser.prog = args.prog
70-
with extract_stdout(args.output) as stdout:
59+
60+
@contextmanager
61+
def _open(out_path):
62+
if str(out_path) in ("-", "stdout"):
63+
yield sys.stdout
64+
else:
65+
with out_path.open('w') as fd:
66+
yield fd
67+
68+
with _open(args.output) as fd:
7169
print(
7270
complete(other_parser, shell=args.shell, root_prefix=args.prefix
73-
or args.parser.split(".", 1)[0], preamble=args.preamble), file=stdout)
71+
or args.parser.split(".", 1)[0], preamble=args.preamble), file=fd)

tests/test_shtab.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010

1111
import shtab
12-
from shtab.main import extract_stdout, get_main_parser, main
12+
from shtab.main import get_main_parser, main
1313

1414
fix_shell = pytest.mark.parametrize("shell", shtab.SUPPORTED_SHELLS)
1515

@@ -82,6 +82,31 @@ def test_main_self_completion(shell, caplog, capsys):
8282
assert not caplog.record_tuples
8383

8484

85+
@pytest.mark.parametrize('output', ["-", "stdout", "test.txt"])
86+
@fix_shell
87+
def test_main_output_path(shell, caplog, capsys, change_dir, output):
88+
assert not capsys.readouterr().out
89+
with caplog.at_level(logging.INFO):
90+
try:
91+
main(["-s", shell, "shtab.main.get_main_parser", "-o", output])
92+
except SystemExit:
93+
pass
94+
95+
captured = capsys.readouterr()
96+
assert not captured.err
97+
expected = {
98+
"bash": "complete -o filenames -F _shtab_shtab shtab", "zsh": "_shtab_shtab_commands()",
99+
"tcsh": "complete shtab"}
100+
101+
if output in ("-", "stdout"):
102+
assert expected[shell] in captured.out
103+
else:
104+
assert not captured.out
105+
assert expected[shell] in (change_dir / output).read_text()
106+
107+
assert not caplog.record_tuples
108+
109+
85110
@fix_shell
86111
def test_prog_override(shell, caplog, capsys):
87112
with caplog.at_level(logging.INFO):
@@ -342,17 +367,3 @@ def test_path_completion_after_redirection(caplog, change_dir):
342367
shell.test('"${COMPREPLY[@]}" = "test_file.txt"', f"Redirection {redirection} failed")
343368

344369
assert not caplog.record_tuples
345-
346-
347-
def test_extract_stdout(tmp_path):
348-
path = tmp_path / "completions"
349-
with extract_stdout(path) as output:
350-
output.write("completion")
351-
assert path.read_text() == "completion"
352-
353-
354-
def test_extract_stdout_empty(capsys):
355-
with extract_stdout(None) as output:
356-
output.write("completion")
357-
captured = capsys.readouterr()
358-
assert captured.out == "completion"

0 commit comments

Comments
 (0)