16 | 16 |
17 | 17 | # conftest.py |
18 | 18 | import gc |
19 | | -import os |
20 | | -import random |
21 | | -import signal |
22 | | -import time |
23 | 19 | from pathlib import Path |
24 | 20 |
25 | | -import numpy as np |
26 | 21 | import pytest |
27 | 22 | import torch |
28 | 23 |
29 | | -from .utils import clean_up_distributed_and_parallel_states |
30 | | - |
31 | 24 |
32 | 25 | def get_device_and_memory_allocated() -> str: |
33 | 26 | """Get the current device index, name, and memory usage.""" |
@@ -68,130 +61,13 @@ def pytest_sessionfinish(session, exitstatus): |
68 | 61 | ) |
69 | 62 |
70 | 63 |
71 | | -def _cleanup_child_processes(): |
72 | | - """Kill any orphaned child processes that might be holding GPU memory. |
73 | | -
74 | | - This is particularly important for tests that spawn subprocesses via torchrun. |
75 | | -
76 | | - Note: Skips cleanup when distributed is initialized, as killing processes |
77 | | - could interfere with NCCL's internal state and cause hangs. |
78 | | - """ |
79 | | - # Don't kill child processes if distributed is active - NCCL might have |
80 | | - # internal processes that should not be killed |
81 | | - if torch.distributed.is_initialized(): |
82 | | - return |
83 | | - |
84 | | - import subprocess |
85 | | - |
86 | | - current_pid = os.getpid() |
87 | | - try: |
88 | | - # Find child processes |
89 | | - result = subprocess.run( |
90 | | - ["pgrep", "-P", str(current_pid)], check=False, capture_output=True, text=True, timeout=5 |
91 | | - ) |
92 | | - child_pids = result.stdout.strip().split("\n") |
93 | | - for pid_str in child_pids: |
94 | | - if pid_str: |
95 | | - try: |
96 | | - pid = int(pid_str) |
97 | | - os.kill(pid, signal.SIGTERM) |
98 | | - except (ValueError, ProcessLookupError, PermissionError): |
99 | | - pass |
100 | | - except (subprocess.TimeoutExpired, FileNotFoundError): |
101 | | - pass |
102 | | - |
103 | | - |
104 | | -def _thorough_gpu_cleanup(): |
105 | | - """Perform thorough GPU memory cleanup. |
106 | | -
107 | | - Note: This is intentionally gentle when torch.distributed is initialized, |
108 | | - as aggressive cleanup can interfere with NCCL state and cause hangs in |
109 | | - subsequent tests that reinitialize distributed. |
110 | | - """ |
111 | | - if not torch.cuda.is_available(): |
112 | | - return |
113 | | - |
114 | | - # If distributed is still initialized, skip aggressive cleanup to avoid |
115 | | - # interfering with NCCL's internal state. The test fixture should handle |
116 | | - # distributed cleanup before this runs, but if it hasn't, we don't want |
117 | | - # to cause issues. |
118 | | - if torch.distributed.is_initialized(): |
119 | | - # Just do minimal cleanup - gc only |
120 | | - gc.collect() |
121 | | - return |
122 | | - |
123 | | - # Synchronize all CUDA streams to ensure all operations are complete |
124 | | - torch.cuda.synchronize() |
125 | | - |
126 | | - # Clear all cached memory |
127 | | - torch.cuda.empty_cache() |
128 | | - |
129 | | - # Reset peak memory stats |
130 | | - torch.cuda.reset_peak_memory_stats() |
131 | | - |
132 | | - # Run garbage collection multiple times to ensure all objects are collected |
133 | | - for _ in range(3): |
134 | | - gc.collect() |
135 | | - |
136 | | - # Another sync and cache clear after gc |
137 | | - torch.cuda.synchronize() |
138 | | - torch.cuda.empty_cache() |
139 | | - |
140 | | - # Small sleep to allow GPU memory to be fully released |
141 | | - time.sleep(0.1) |
142 | | - |
143 | | - |
144 | | -def _reset_random_seeds(): |
145 | | - """Reset random seeds to ensure reproducibility across tests. |
146 | | -
147 | | - Some tests may modify global random state, which can affect subsequent tests |
148 | | - that depend on random splitting (like dataset preprocessing). |
149 | | -
150 | | - Note: Skips CUDA seed reset when distributed is initialized, as this can |
151 | | - interfere with Megatron's tensor parallel RNG tracker. |
152 | | - """ |
153 | | - # Reset Python's random module |
154 | | - random.seed(None) |
155 | | - |
156 | | - # Reset NumPy's random state (intentionally using legacy API to reset global state) |
157 | | - np.random.seed(None) # noqa: NPY002 |
158 | | - |
159 | | - # Reset PyTorch's random state |
160 | | - torch.seed() |
161 | | - |
162 | | - # Only reset CUDA seeds if distributed is not initialized, as the distributed |
163 | | - # tests manage their own RNG state via model_parallel_cuda_manual_seed |
164 | | - if torch.cuda.is_available() and not torch.distributed.is_initialized(): |
165 | | - torch.cuda.seed_all() |
166 | | - |
167 | | - |
168 | 64 | @pytest.fixture(autouse=True) |
169 | 65 | def cleanup_after_test(): |
170 | | - """Clean up GPU memory and reset state after each test. |
171 | | -
172 | | - This fixture provides a safety net for tests that may not properly clean up |
173 | | - their distributed/parallel state. It uses the shared cleanup function from |
174 | | - utils.py as the canonical cleanup, then performs additional GPU memory cleanup. |
175 | | - """ |
176 | | - # Reset random seeds before the test to ensure reproducibility |
177 | | - _reset_random_seeds() |
178 | | - |
| 66 | +    """Release cached GPU memory and run garbage collection after each test.""" |
179 | 67 | yield |
180 | | - |
181 | | - # First, ensure any lingering distributed/parallel state is cleaned up. |
182 | | - # This is a safety net - tests using distributed_model_parallel_state should |
183 | | - # already have cleaned up, but this catches any that didn't. |
184 | | - # This function is safe to call even if distributed is not initialized. |
185 | | - clean_up_distributed_and_parallel_states() |
186 | | - |
187 | | - # After distributed cleanup, perform thorough GPU memory cleanup |
188 | | - _thorough_gpu_cleanup() |
189 | | - |
190 | | - # Clean up any orphaned child processes (important for subprocess tests) |
191 | | - _cleanup_child_processes() |
192 | | - |
193 | | - # Final garbage collection |
194 | | - gc.collect() |
| 68 | + if torch.cuda.is_available(): |
| 69 | + torch.cuda.empty_cache() |
| 70 | + gc.collect() |
195 | 71 |
196 | 72 |
197 | 73 | def pytest_addoption(parser: pytest.Parser): |
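For reference, here is a minimal sketch of how the simplified autouse fixture reads after this change, reconstructed from the `+` lines in the diff above (not the full `conftest.py`). The imports are taken from the surviving import block at the top of the diff; the placement of `gc.collect()` outside the `if` block is assumed, since the rendered diff does not make the nesting explicit.

```python
# Sketch of the simplified cleanup fixture after this change,
# reconstructed from the added lines in the diff above.
import gc

import pytest
import torch


@pytest.fixture(autouse=True)
def cleanup_after_test():
    """Release cached GPU memory and run garbage collection after each test."""
    yield
    if torch.cuda.is_available():
        # Return unused cached blocks held by the CUDA caching allocator.
        torch.cuda.empty_cache()
    # Collect cyclic garbage so tensors freed by Python are actually released.
    # Shown unconditional here; the exact nesting is assumed from the diff context.
    gc.collect()
```

Compared with the removed helpers, this keeps only the cheap, always-safe steps (cache release and garbage collection) and drops the child-process, seed-reset, and synchronization logic shown deleted above.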