Skip to content

Commit bc97b21

Browse files
committed
refactor
1 parent 20e1c4a commit bc97b21

File tree

3 files changed

+28
-28
lines changed

3 files changed

+28
-28
lines changed

pgcli/main.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,20 @@ def get_completions(self, text, cursor_positition):
12481248
with self._completer_lock:
12491249
return self.completer.get_completions(Document(text=text, cursor_position=cursor_positition), None)
12501250

1251+
def _get_transaction_status_char(self) -> str:
1252+
"""Return transaction status character for prompt.
1253+
1254+
Following psql convention:
1255+
- '*' when in a valid transaction (ACTIVE or INTRANS)
1256+
- '!' when in a failed transaction (INERROR)
1257+
- '' when idle
1258+
"""
1259+
if self.pgexecute.failed_transaction():
1260+
return "!"
1261+
if self.pgexecute.valid_transaction():
1262+
return "*"
1263+
return ""
1264+
12511265
def get_prompt(self, string):
12521266
# should be before replacing \\d
12531267
string = string.replace("\\dsn_alias", self.dsn_alias or "")
@@ -1263,7 +1277,7 @@ def get_prompt(self, string):
12631277
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
12641278
string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
12651279
string = string.replace("\\n", "\n")
1266-
string = string.replace("\\x", self.pgexecute.transaction_status())
1280+
string = string.replace("\\x", self._get_transaction_status_char())
12671281
return string
12681282

12691283
def get_last_query(self):

pgcli/pgexecute.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -298,20 +298,6 @@ def valid_transaction(self):
298298
status = self.conn.info.transaction_status
299299
return status == psycopg.pq.TransactionStatus.ACTIVE or status == psycopg.pq.TransactionStatus.INTRANS
300300

301-
def transaction_status(self):
302-
"""Return transaction status character for prompt.
303-
304-
Following psql convention:
305-
- '*' when in a valid transaction (ACTIVE or INTRANS)
306-
- '!' when in a failed transaction (INERROR)
307-
- '' when idle
308-
"""
309-
if self.failed_transaction():
310-
return "!"
311-
elif self.valid_transaction():
312-
return "*"
313-
return ""
314-
315301
def run(
316302
self,
317303
statement,

tests/test_main.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -578,14 +578,14 @@ def test_duration_in_words(duration_in_seconds, words):
578578

579579

580580
@pytest.mark.parametrize(
581-
"transaction_status,expected",
581+
"is_failed,is_valid,expected",
582582
[
583-
("*", "*testuser"),
584-
("!", "!testuser"),
585-
("", "testuser"),
583+
(False, True, "*testuser"), # valid transaction → "*"
584+
(True, False, "!testuser"), # failed transaction → "!"
585+
(False, False, "testuser"), # idle → ""
586586
],
587587
)
588-
def test_get_prompt_with_transaction_status(transaction_status, expected):
588+
def test_get_prompt_with_transaction_status(is_failed, is_valid, expected):
589589
"""Test that \\x prompt variable shows transaction status."""
590590
cli = PGCli()
591591
cli.pgexecute = mock.MagicMock()
@@ -597,10 +597,10 @@ def test_get_prompt_with_transaction_status(transaction_status, expected):
597597
cli.pgexecute.pid = 12345
598598
cli.pgexecute.superuser = False
599599

600-
with mock.patch.object(
601-
cli.pgexecute, "transaction_status", return_value=transaction_status
602-
):
603-
result = cli.get_prompt("\\x\\u")
600+
cli.pgexecute.failed_transaction.return_value = is_failed
601+
cli.pgexecute.valid_transaction.return_value = is_valid
602+
603+
result = cli.get_prompt("\\x\\u")
604604
assert result == expected
605605

606606

@@ -616,10 +616,10 @@ def test_get_prompt_transaction_status_in_full_prompt():
616616
cli.pgexecute.pid = 12345
617617
cli.pgexecute.superuser = False
618618

619-
with mock.patch.object(
620-
cli.pgexecute, "transaction_status", return_value="*"
621-
):
622-
result = cli.get_prompt("\\x\\u@\\h:\\d> ")
619+
cli.pgexecute.failed_transaction.return_value = False
620+
cli.pgexecute.valid_transaction.return_value = True
621+
622+
result = cli.get_prompt("\\x\\u@\\h:\\d> ")
623623
assert result == "*user@db.example.com:mydb> "
624624

625625

0 commit comments

Comments
 (0)