Skip to content

Commit 6d03f4a

Browse files
authored
Merge branch 'master' into refactor/tensorboard_log_metrics_handling
2 parents e87e9a9 + c05cadb commit 6d03f4a

File tree

4 files changed

+37
-2
lines changed

4 files changed

+37
-2
lines changed

docs/source-pytorch/conf.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -639,4 +639,6 @@ def package_list_from_file(file):
639639
"https://openai.com/index/*",
640640
"https://tinyurl.com/.*", # has a human verification check on redirect
641641
"https://docs.neptune.ai/.*", # TODO: remove after dropping Neptune support
642+
"https://app.neptune.ai/*",
643+
"https://www.neptune.ai/*"
642644
]

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
2929

3030
-
3131

32+
- Fixed ``RichModelSummary`` model size display formatting ([#21467](https://github.com/Lightning-AI/pytorch-lightning/pull/21467))
3233

3334
---
3435

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from lightning.pytorch.callbacks import ModelSummary
1919
from lightning.pytorch.utilities.imports import _RICH_AVAILABLE
20-
from lightning.pytorch.utilities.model_summary import get_human_readable_count
20+
from lightning.pytorch.utilities.model_summary import get_formatted_model_size, get_human_readable_count
2121

2222

2323
class RichModelSummary(ModelSummary):
@@ -105,8 +105,9 @@ def summarize(
105105
console.print(table)
106106

107107
parameters = []
108-
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters, model_size]:
108+
for param in [trainable_parameters, total_parameters - trainable_parameters, total_parameters]:
109109
parameters.append("{:<{}}".format(get_human_readable_count(int(param)), 10))
110+
parameters.append("{:<{}}".format(get_formatted_model_size(model_size), 10))
110111

111112
grid = Table.grid(expand=True)
112113
grid.add_column()

tests/tests_pytorch/callbacks/test_rich_model_summary.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,3 +70,34 @@ def example_input_array(self) -> Any:
7070
# assert that the input summary data was converted correctly
7171
args, _ = mock_table_add_row.call_args_list[0]
7272
assert args[1:] == ("0", "layer", "Linear", "66 ", "train", "512 ", "[4, 32]", "[4, 2]")
73+
74+
75+
@RunIf(rich=True)
76+
def test_rich_summary_model_size_formatting():
77+
"""Ensure model_size uses get_formatted_model_size, not get_human_readable_count."""
78+
from io import StringIO
79+
80+
from rich.console import Console
81+
82+
model_summary = RichModelSummary()
83+
model = BoringModel()
84+
summary = summarize(model)
85+
summary_data = summary._get_summary_data()
86+
87+
output = StringIO()
88+
console = Console(file=output, force_terminal=True)
89+
90+
with mock.patch("rich.get_console", return_value=console):
91+
model_summary.summarize(
92+
summary_data=summary_data,
93+
total_parameters=1,
94+
trainable_parameters=1,
95+
model_size=5500.0,
96+
total_training_modes=summary.total_training_modes,
97+
total_flops=1,
98+
)
99+
100+
result = output.getvalue()
101+
# model_size=5500.0 should display as "5,500.000" (formatted), not "5.5 K" (human readable count)
102+
assert "5,500.000" in result
103+
assert "5.5 K" not in result

0 commit comments

Comments
 (0)