
Commit dff8ed1

Restore conftest before the AI additions to see if that fixes the stalling problems
Signed-off-by: John St. John <jstjohn@nvidia.com>
1 parent d0864f6 commit dff8ed1

File tree: 3 files changed, +62 -138 lines changed


bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/conftest.py

Lines changed: 4 additions & 128 deletions
@@ -16,18 +16,11 @@
 
 # conftest.py
 import gc
-import os
-import random
-import signal
-import time
 from pathlib import Path
 
-import numpy as np
 import pytest
 import torch
 
-from .utils import clean_up_distributed_and_parallel_states
-
 
 def get_device_and_memory_allocated() -> str:
     """Get the current device index, name, and memory usage."""
@@ -68,130 +61,13 @@ def pytest_sessionfinish(session, exitstatus):
     )
 
 
-def _cleanup_child_processes():
-    """Kill any orphaned child processes that might be holding GPU memory.
-
-    This is particularly important for tests that spawn subprocesses via torchrun.
-
-    Note: Skips cleanup when distributed is initialized, as killing processes
-    could interfere with NCCL's internal state and cause hangs.
-    """
-    # Don't kill child processes if distributed is active - NCCL might have
-    # internal processes that should not be killed
-    if torch.distributed.is_initialized():
-        return
-
-    import subprocess
-
-    current_pid = os.getpid()
-    try:
-        # Find child processes
-        result = subprocess.run(
-            ["pgrep", "-P", str(current_pid)], check=False, capture_output=True, text=True, timeout=5
-        )
-        child_pids = result.stdout.strip().split("\n")
-        for pid_str in child_pids:
-            if pid_str:
-                try:
-                    pid = int(pid_str)
-                    os.kill(pid, signal.SIGTERM)
-                except (ValueError, ProcessLookupError, PermissionError):
-                    pass
-    except (subprocess.TimeoutExpired, FileNotFoundError):
-        pass
-
-
-def _thorough_gpu_cleanup():
-    """Perform thorough GPU memory cleanup.
-
-    Note: This is intentionally gentle when torch.distributed is initialized,
-    as aggressive cleanup can interfere with NCCL state and cause hangs in
-    subsequent tests that reinitialize distributed.
-    """
-    if not torch.cuda.is_available():
-        return
-
-    # If distributed is still initialized, skip aggressive cleanup to avoid
-    # interfering with NCCL's internal state. The test fixture should handle
-    # distributed cleanup before this runs, but if it hasn't, we don't want
-    # to cause issues.
-    if torch.distributed.is_initialized():
-        # Just do minimal cleanup - gc only
-        gc.collect()
-        return
-
-    # Synchronize all CUDA streams to ensure all operations are complete
-    torch.cuda.synchronize()
-
-    # Clear all cached memory
-    torch.cuda.empty_cache()
-
-    # Reset peak memory stats
-    torch.cuda.reset_peak_memory_stats()
-
-    # Run garbage collection multiple times to ensure all objects are collected
-    for _ in range(3):
-        gc.collect()
-
-    # Another sync and cache clear after gc
-    torch.cuda.synchronize()
-    torch.cuda.empty_cache()
-
-    # Small sleep to allow GPU memory to be fully released
-    time.sleep(0.1)
-
-
-def _reset_random_seeds():
-    """Reset random seeds to ensure reproducibility across tests.
-
-    Some tests may modify global random state, which can affect subsequent tests
-    that depend on random splitting (like dataset preprocessing).
-
-    Note: Skips CUDA seed reset when distributed is initialized, as this can
-    interfere with Megatron's tensor parallel RNG tracker.
-    """
-    # Reset Python's random module
-    random.seed(None)
-
-    # Reset NumPy's random state (intentionally using legacy API to reset global state)
-    np.random.seed(None)  # noqa: NPY002
-
-    # Reset PyTorch's random state
-    torch.seed()
-
-    # Only reset CUDA seeds if distributed is not initialized, as the distributed
-    # tests manage their own RNG state via model_parallel_cuda_manual_seed
-    if torch.cuda.is_available() and not torch.distributed.is_initialized():
-        torch.cuda.seed_all()
-
-
 @pytest.fixture(autouse=True)
 def cleanup_after_test():
-    """Clean up GPU memory and reset state after each test.
-
-    This fixture provides a safety net for tests that may not properly clean up
-    their distributed/parallel state. It uses the shared cleanup function from
-    utils.py as the canonical cleanup, then performs additional GPU memory cleanup.
-    """
-    # Reset random seeds before the test to ensure reproducibility
-    _reset_random_seeds()
-
+    """Clean up GPU memory and reset state after each test."""
     yield
-
-    # First, ensure any lingering distributed/parallel state is cleaned up.
-    # This is a safety net - tests using distributed_model_parallel_state should
-    # already have cleaned up, but this catches any that didn't.
-    # This function is safe to call even if distributed is not initialized.
-    clean_up_distributed_and_parallel_states()
-
-    # After distributed cleanup, perform thorough GPU memory cleanup
-    _thorough_gpu_cleanup()
-
-    # Clean up any orphaned child processes (important for subprocess tests)
-    _cleanup_child_processes()
-
-    # Final garbage collection
-    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    gc.collect()
 
 
 def pytest_addoption(parser: pytest.Parser):
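
For reference, the restored fixture is small enough to read in one piece. The sketch below reconstructs the post-commit cleanup path in conftest.py from the diff above; only the fixture and its imports are shown, and unchanged helpers such as get_device_and_memory_allocated and pytest_addoption are omitted:

```python
# Reconstruction of the restored autouse fixture, assembled from the diff above.
import gc

import pytest
import torch


@pytest.fixture(autouse=True)
def cleanup_after_test():
    """Clean up GPU memory and reset state after each test."""
    yield
    # Post-test cleanup only: release cached CUDA allocations and collect
    # garbage. No child-process killing, CUDA syncs, or reseeding, all of
    # which the removed helpers did and which could interfere with NCCL
    # state while torch.distributed is still initialized.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()
```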

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_evo2.py

Lines changed: 14 additions & 2 deletions
@@ -541,7 +541,13 @@ def calculate_sequence_identity(seq1: str, seq2: str) -> float | None:
 @pytest.mark.parametrize(
     "ckpt_name,expected_matchpercents,fp8",
     [
-        pytest.param("evo2/1b-8k-bf16:1.0", [86.4, 78.8, 49.7], False, id="1b-bf16_bf16"),
+        pytest.param(
+            "evo2/1b-8k-bf16:1.0",
+            [86.4, 78.8, 49.7],
+            False,
+            id="1b-bf16_bf16",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to slow speed"),
+        ),
         pytest.param("evo2/1b-8k-bf16:1.0", [86.4, 78.8, 49.7], True, id="1b-bf16_fp8"),
         pytest.param(
             "evo2/1b-8k:1.0",
@@ -701,7 +707,13 @@ def test_batch_generate_coding_sequences(
 @pytest.mark.parametrize(
     "ckpt_name,expected_matchpercents,fp8",
     [
-        pytest.param("evo2/1b-8k-bf16:1.0", [96.8, 29.7, 76.6, 71.6], False, id="1b-bf16_bf16"),
+        pytest.param(
+            "evo2/1b-8k-bf16:1.0",
+            [96.8, 29.7, 76.6, 71.6],
+            False,
+            id="1b-bf16_bf16",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI due to slow speed"),
+        ),
         pytest.param("evo2/1b-8k-bf16:1.0", [96.8, 29.7, 76.6, 71.6], True, id="1b-bf16_fp8"),
         pytest.param(
             "evo2/1b-8k:1.0",

bionemo-recipes/recipes/evo2_megatron/tests/bionemo/evo2/test_stop_and_go.py

Lines changed: 44 additions & 8 deletions
@@ -38,14 +38,50 @@
     [
         (1, 1, 1, False, "bf16_mixed"),
         (1, 1, 1, False, "bf16_with_fp8_current_scaling_mixed"),
-        (1, 1, 1, False, "bf16_with_fp8_delayed_scaling_mixed"),  # XFAIL
-        (1, 1, 1, False, "bf16_with_fp8_subchannel_scaling_mixed"),
-        (1, 1, 1, False, "bf16_with_nvfp4_mixed"),  # XFAIL other than blackwell+
-        (1, 1, 1, False, "bf16_with_mxfp8_mixed"),  # XFAIL other than blackwell+
-        (1, 1, 2, True, "bf16_mixed"),
-        (1, 1, 2, False, "bf16_mixed"),
-        (1, 2, 1, True, "bf16_mixed"),
-        (2, 1, 1, False, "bf16_mixed"),
+        pytest.param(
+            1,
+            1,
+            1,
+            False,
+            "bf16_with_fp8_delayed_scaling_mixed",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI"),
+        ),  # XFAIL
+        pytest.param(
+            1,
+            1,
+            1,
+            False,
+            "bf16_with_fp8_subchannel_scaling_mixed",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI"),
+        ),
+        pytest.param(
+            1,
+            1,
+            1,
+            False,
+            "bf16_with_nvfp4_mixed",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI"),
+        ),  # XFAIL other than blackwell+
+        pytest.param(
+            1,
+            1,
+            1,
+            False,
+            "bf16_with_mxfp8_mixed",
+            marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI"),
+        ),  # XFAIL other than blackwell+
+        pytest.param(
+            1, 1, 2, True, "bf16_mixed", marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")
+        ),
+        pytest.param(
+            1, 1, 2, False, "bf16_mixed", marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")
+        ),
+        pytest.param(
+            1, 2, 1, True, "bf16_mixed", marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")
+        ),
+        pytest.param(
+            2, 1, 1, False, "bf16_mixed", marks=pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")
+        ),
     ],
 )
 @pytest.mark.slow
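
The same skipif expression now appears verbatim in every parameter set above. A possible follow-up (a suggestion, not part of this commit) is to bind the mark to a module-level name so the condition and reason live in one place:

```python
import os

import pytest

# Hypothetical refactor: define the CI-skip mark once and reuse it.
SKIP_IN_CI = pytest.mark.skipif(bool(os.environ.get("CI")), reason="Skip in CI")

PARALLEL_CONFIGS = [  # hypothetical name for the parametrize list
    (1, 1, 1, False, "bf16_mixed"),  # still runs in CI
    pytest.param(1, 1, 2, True, "bf16_mixed", marks=SKIP_IN_CI),
    pytest.param(2, 1, 1, False, "bf16_mixed", marks=SKIP_IN_CI),
]
```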
