diff --git a/src/fairseq2/__init__.py b/src/fairseq2/__init__.py index d77e35194..420ff999b 100644 --- a/src/fairseq2/__init__.py +++ b/src/fairseq2/__init__.py @@ -10,6 +10,8 @@ import fairseq2n # Report any fairseq2n initialization error eagerly. +from fairseq2.device import Device +from fairseq2.utils.warn import enable_deprecation_warnings import fairseq2.runtime.dependency from fairseq2.error import InvalidOperationError from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver diff --git a/src/fairseq2/recipe/composition/device_stat.py b/src/fairseq2/recipe/composition/device_stat.py index 530f973a4..6171f61c4 100644 --- a/src/fairseq2/recipe/composition/device_stat.py +++ b/src/fairseq2/recipe/composition/device_stat.py @@ -8,6 +8,7 @@ from fairseq2.recipe.internal.device_stat import _RecipeDeviceStatTrackerProvider from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver +from fairseq2.utils.cpu_stat import CpuDeviceStatTracker from fairseq2.utils.device_stat import CudaDeviceStatTracker, DeviceStatTracker @@ -22,5 +23,8 @@ def get_device_stat_tracker(resolver: DependencyResolver) -> DeviceStatTracker: container.register_type(_RecipeDeviceStatTrackerProvider) + # CPU + container.register_type(DeviceStatTracker, CpuDeviceStatTracker, key="cpu") + # CUDA container.register_type(DeviceStatTracker, CudaDeviceStatTracker, key="cuda") diff --git a/src/fairseq2/utils/cpu_stat.py b/src/fairseq2/utils/cpu_stat.py new file mode 100644 index 000000000..9cc2915a9 --- /dev/null +++ b/src/fairseq2/utils/cpu_stat.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import os +from typing import final + +from typing_extensions import override + +from fairseq2.device import Device +from fairseq2.error import OperationalError +from fairseq2.utils.device_stat import DeviceStatTracker + +try: + import psutil +except ImportError: + psutil = None # type: ignore[assignment] + + +@final +class CpuDeviceStatTracker(DeviceStatTracker): + """Tracks CPU and memory statistics for the current process.""" + + def __init__(self, device: Device) -> None: + if device.type != "cpu": + raise ValueError( + f"`device.type` must be `cpu`, but is `{device.type}` instead." + ) + + if psutil is None: + raise OperationalError( + "psutil is not installed. Install it with: pip install psutil" + ) + + self._device = device + self._process = psutil.Process(os.getpid()) + + self._peak_memory_rss = 0 + self._peak_memory_vms = 0 + + @override + def get_stats(self) -> dict[str, object]: + try: + mem_info = self._process.memory_info() + mem_percent = self._process.memory_percent() + + current_rss = mem_info.rss + current_vms = mem_info.vms + + self._peak_memory_rss = max(self._peak_memory_rss, current_rss) + self._peak_memory_vms = max(self._peak_memory_vms, current_vms) + + stats = { + "peak_memory_rss_bytes": self._peak_memory_rss, + "peak_memory_vms_bytes": self._peak_memory_vms, + "memory_percent": mem_percent, + "num_threads": self._process.num_threads(), + } + + cpu_percent = self._process.cpu_percent(interval=None) + if cpu_percent is not None and cpu_percent > 0: + stats["cpu_percent"] = cpu_percent + + try: + load_avg = os.getloadavg() + stats["load_average_1m"] = load_avg[0] + stats["load_average_5m"] = load_avg[1] + stats["load_average_15m"] = load_avg[2] + except (AttributeError, OSError): + pass + + return stats + + except Exception as ex: + raise OperationalError("Failed to collect CPU statistics.") from ex + + @override + def reset(self) -> None: + self._peak_memory_rss = 0 + self._peak_memory_vms = 0 + try: + self._process.cpu_percent(interval=None) + except Exception: + pass diff --git a/src/fairseq2/utils/device_stat.py b/src/fairseq2/utils/device_stat.py index af99a7722..8a47456fe 100644 --- a/src/fairseq2/utils/device_stat.py +++ b/src/fairseq2/utils/device_stat.py @@ -15,6 +15,8 @@ from fairseq2.error import OperationalError + + class DeviceStatTracker(ABC): @abstractmethod def get_stats(self) -> dict[str, object]: ... diff --git a/test_cpu_tracker.py b/test_cpu_tracker.py new file mode 100644 index 000000000..7e415a51c --- /dev/null +++ b/test_cpu_tracker.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +""" +DEVELOPMENT ONLY FILE + +Simple test script to verify CPU stat tracker implementation. +Run this after installing fairseq2 with: pip install -e . +""" + +from fairseq2.device import Device +from fairseq2.utils.cpu_stat import CpuDeviceStatTracker + + +def test_cpu_tracker(): + print("Testing CpuDeviceStatTracker...") + + device = Device("cpu") + tracker = CpuDeviceStatTracker(device) + + print("\n1. Getting initial stats...") + stats = tracker.get_stats() + + print("Stats collected:") + for key, value in stats.items(): + if isinstance(value, float): + print(f" {key}: {value:.2f}") + else: + print(f" {key}: {value}") + + print("\n2. Allocating some memory...") + data = [0] * 1000000 + + print("\n3. Getting stats after allocation...") + stats_after = tracker.get_stats() + + print("Stats after allocation:") + for key, value in stats_after.items(): + if isinstance(value, float): + print(f" {key}: {value:.2f}") + else: + print(f" {key}: {value}") + + print("\n4. Verifying peak memory increased...") + assert stats_after["peak_memory_rss_bytes"] >= stats["peak_memory_rss_bytes"] + print("āœ“ Peak memory tracking works!") + + print("\n5. Resetting tracker...") + tracker.reset() + + print("\n6. Getting stats after reset...") + stats_reset = tracker.get_stats() + print("Stats after reset:") + for key, value in stats_reset.items(): + if isinstance(value, float): + print(f" {key}: {value:.2f}") + else: + print(f" {key}: {value}") + + print("\nāœ“ All tests passed!") + + +if __name__ == "__main__": + test_cpu_tracker() diff --git a/tests/conftest.py b/tests/conftest.py index 13e36e388..5d40483f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,18 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. + from __future__ import annotations +import tests.common from argparse import ArgumentTypeError from pathlib import Path from typing import cast from pytest import Config, Parser, Session -import tests.common + from fairseq2 import init_fairseq2 from fairseq2.device import Device -from fairseq2.utils.warn import enable_deprecation_warnings +from fairseq2 import enable_deprecation_warnings def pytest_addoption(parser: Parser) -> None: diff --git a/tests/unit/recipe/test_device_stat.py b/tests/unit/recipe/test_device_stat.py new file mode 100644 index 000000000..a1355203b --- /dev/null +++ b/tests/unit/recipe/test_device_stat.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pytest + +from fairseq2.device import Device +from fairseq2.utils.cpu_stat import CpuDeviceStatTracker +from fairseq2.utils.device_stat import CudaDeviceStatTracker + + +class TestDeviceStatTrackers: + """Test device stat tracker implementations directly.""" + + def test_cpu_tracker_collects_stats(self) -> None: + device = Device("cpu") + tracker = CpuDeviceStatTracker(device) + + stats = tracker.get_stats() + assert isinstance(stats, dict) + assert "peak_memory_rss_bytes" in stats or "memory_percent" in stats + + # Test reset + tracker.reset() + stats_after_reset = tracker.get_stats() + assert isinstance(stats_after_reset, dict) + + def test_cuda_tracker_collects_stats(self) -> None: + try: + device = Device("cuda:0") + except RuntimeError: + pytest.skip("CUDA not available") + + try: + from fairseq2.device import CudaContext + cuda_context = CudaContext() + tracker = CudaDeviceStatTracker(device, cuda_context) + + stats = tracker.get_stats() + assert isinstance(stats, dict) + assert "peak_active_mem_bytes" in stats or "peak_reserved_mem_bytes" in stats + + # Test reset + tracker.reset() + stats_after_reset = tracker.get_stats() + assert isinstance(stats_after_reset, dict) + except Exception: + pytest.skip("CUDA context not available") \ No newline at end of file diff --git a/tests/unit/utils/test_cpu_stat.py b/tests/unit/utils/test_cpu_stat.py new file mode 100644 index 000000000..8c5d99912 --- /dev/null +++ b/tests/unit/utils/test_cpu_stat.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import pytest + +from fairseq2.device import Device +from fairseq2.error import OperationalError +from fairseq2.utils.cpu_stat import CpuDeviceStatTracker + + +class TestCpuDeviceStatTracker: + def test_init_validates_device_type(self) -> None: + with pytest.raises(ValueError, match=r"`device.type` must be `cpu`"): + CpuDeviceStatTracker(Device("cuda")) + + def test_get_stats_returns_valid_metrics(self) -> None: + tracker = CpuDeviceStatTracker(Device("cpu")) + + stats = tracker.get_stats() + + assert "peak_memory_rss_bytes" in stats + assert "peak_memory_vms_bytes" in stats + assert "memory_percent" in stats + assert "num_threads" in stats + + assert stats["peak_memory_rss_bytes"] > 0 + assert stats["peak_memory_vms_bytes"] > 0 + assert 0 <= stats["memory_percent"] <= 100 + assert stats["num_threads"] > 0 + + def test_get_stats_tracks_peak_memory(self) -> None: + tracker = CpuDeviceStatTracker(Device("cpu")) + + stats1 = tracker.get_stats() + peak1 = stats1["peak_memory_rss_bytes"] + + _ = [0] * 10000 + + stats2 = tracker.get_stats() + peak2 = stats2["peak_memory_rss_bytes"] + + assert peak2 >= peak1 + + def test_reset_clears_peak_memory(self) -> None: + tracker = CpuDeviceStatTracker(Device("cpu")) + + _ = tracker.get_stats() + + tracker.reset() + + stats = tracker.get_stats() + assert stats["peak_memory_rss_bytes"] >= 0 + assert stats["peak_memory_vms_bytes"] >= 0 + \ No newline at end of file