Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/fairseq2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

import fairseq2n # Report any fairseq2n initialization error eagerly.

from fairseq2.device import Device
from fairseq2.utils.warn import enable_deprecation_warnings
import fairseq2.runtime.dependency
from fairseq2.error import InvalidOperationError
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver
Expand Down
4 changes: 4 additions & 0 deletions src/fairseq2/recipe/composition/device_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from fairseq2.recipe.internal.device_stat import _RecipeDeviceStatTrackerProvider
from fairseq2.runtime.dependency import DependencyContainer, DependencyResolver
from fairseq2.utils.cpu_stat import CpuDeviceStatTracker
from fairseq2.utils.device_stat import CudaDeviceStatTracker, DeviceStatTracker


Expand All @@ -22,5 +23,8 @@ def get_device_stat_tracker(resolver: DependencyResolver) -> DeviceStatTracker:

container.register_type(_RecipeDeviceStatTrackerProvider)

# CPU
container.register_type(DeviceStatTracker, CpuDeviceStatTracker, key="cpu")

# CUDA
container.register_type(DeviceStatTracker, CudaDeviceStatTracker, key="cuda")
88 changes: 88 additions & 0 deletions src/fairseq2/utils/cpu_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import os
from typing import final

from typing_extensions import override

from fairseq2.device import Device
from fairseq2.error import OperationalError
from fairseq2.utils.device_stat import DeviceStatTracker

try:
import psutil
except ImportError:
psutil = None # type: ignore[assignment]


@final
class CpuDeviceStatTracker(DeviceStatTracker):
"""Tracks CPU and memory statistics for the current process."""

def __init__(self, device: Device) -> None:
if device.type != "cpu":
raise ValueError(
f"`device.type` must be `cpu`, but is `{device.type}` instead."
)

if psutil is None:
raise OperationalError(
"psutil is not installed. Install it with: pip install psutil"
)

self._device = device
self._process = psutil.Process(os.getpid())

self._peak_memory_rss = 0
self._peak_memory_vms = 0

@override
def get_stats(self) -> dict[str, object]:
try:
mem_info = self._process.memory_info()
mem_percent = self._process.memory_percent()

current_rss = mem_info.rss
current_vms = mem_info.vms

self._peak_memory_rss = max(self._peak_memory_rss, current_rss)
self._peak_memory_vms = max(self._peak_memory_vms, current_vms)

stats = {
"peak_memory_rss_bytes": self._peak_memory_rss,
"peak_memory_vms_bytes": self._peak_memory_vms,
"memory_percent": mem_percent,
"num_threads": self._process.num_threads(),
}

cpu_percent = self._process.cpu_percent(interval=None)
if cpu_percent is not None and cpu_percent > 0:
stats["cpu_percent"] = cpu_percent

try:
load_avg = os.getloadavg()
stats["load_average_1m"] = load_avg[0]
stats["load_average_5m"] = load_avg[1]
stats["load_average_15m"] = load_avg[2]
except (AttributeError, OSError):
pass

return stats

except Exception as ex:
raise OperationalError("Failed to collect CPU statistics.") from ex

@override
def reset(self) -> None:
self._peak_memory_rss = 0
self._peak_memory_vms = 0
try:
self._process.cpu_percent(interval=None)
except Exception:
pass
2 changes: 2 additions & 0 deletions src/fairseq2/utils/device_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from fairseq2.error import OperationalError




class DeviceStatTracker(ABC):
@abstractmethod
def get_stats(self) -> dict[str, object]: ...
Expand Down
62 changes: 62 additions & 0 deletions test_cpu_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#!/usr/bin/env python3
"""
DEVELOPMENT ONLY FILE

Simple test script to verify CPU stat tracker implementation.
Run this after installing fairseq2 with: pip install -e .
"""

from fairseq2.device import Device
from fairseq2.utils.cpu_stat import CpuDeviceStatTracker


def test_cpu_tracker():
print("Testing CpuDeviceStatTracker...")

device = Device("cpu")
tracker = CpuDeviceStatTracker(device)

print("\n1. Getting initial stats...")
stats = tracker.get_stats()

print("Stats collected:")
for key, value in stats.items():
if isinstance(value, float):
print(f" {key}: {value:.2f}")
else:
print(f" {key}: {value}")

print("\n2. Allocating some memory...")
data = [0] * 1000000

print("\n3. Getting stats after allocation...")
stats_after = tracker.get_stats()

print("Stats after allocation:")
for key, value in stats_after.items():
if isinstance(value, float):
print(f" {key}: {value:.2f}")
else:
print(f" {key}: {value}")

print("\n4. Verifying peak memory increased...")
assert stats_after["peak_memory_rss_bytes"] >= stats["peak_memory_rss_bytes"]
print("✓ Peak memory tracking works!")

print("\n5. Resetting tracker...")
tracker.reset()

print("\n6. Getting stats after reset...")
stats_reset = tracker.get_stats()
print("Stats after reset:")
for key, value in stats_reset.items():
if isinstance(value, float):
print(f" {key}: {value:.2f}")
else:
print(f" {key}: {value}")

print("\n✓ All tests passed!")


if __name__ == "__main__":
test_cpu_tracker()
6 changes: 4 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


from __future__ import annotations

import tests.common
from argparse import ArgumentTypeError
from pathlib import Path
from typing import cast

from pytest import Config, Parser, Session

import tests.common

from fairseq2 import init_fairseq2
from fairseq2.device import Device
from fairseq2.utils.warn import enable_deprecation_warnings
from fairseq2 import enable_deprecation_warnings


def pytest_addoption(parser: Parser) -> None:
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/recipe/test_device_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import pytest

from fairseq2.device import Device
from fairseq2.utils.cpu_stat import CpuDeviceStatTracker
from fairseq2.utils.device_stat import CudaDeviceStatTracker


class TestDeviceStatTrackers:
"""Test device stat tracker implementations directly."""

def test_cpu_tracker_collects_stats(self) -> None:
device = Device("cpu")
tracker = CpuDeviceStatTracker(device)

stats = tracker.get_stats()
assert isinstance(stats, dict)
assert "peak_memory_rss_bytes" in stats or "memory_percent" in stats

# Test reset
tracker.reset()
stats_after_reset = tracker.get_stats()
assert isinstance(stats_after_reset, dict)

def test_cuda_tracker_collects_stats(self) -> None:
try:
device = Device("cuda:0")
except RuntimeError:
pytest.skip("CUDA not available")

try:
from fairseq2.device import CudaContext
cuda_context = CudaContext()
tracker = CudaDeviceStatTracker(device, cuda_context)

stats = tracker.get_stats()
assert isinstance(stats, dict)
assert "peak_active_mem_bytes" in stats or "peak_reserved_mem_bytes" in stats

# Test reset
tracker.reset()
stats_after_reset = tracker.get_stats()
assert isinstance(stats_after_reset, dict)
except Exception:
pytest.skip("CUDA context not available")
59 changes: 59 additions & 0 deletions tests/unit/utils/test_cpu_stat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import pytest

from fairseq2.device import Device
from fairseq2.error import OperationalError
from fairseq2.utils.cpu_stat import CpuDeviceStatTracker


class TestCpuDeviceStatTracker:
def test_init_validates_device_type(self) -> None:
with pytest.raises(ValueError, match=r"`device.type` must be `cpu`"):
CpuDeviceStatTracker(Device("cuda"))

def test_get_stats_returns_valid_metrics(self) -> None:
tracker = CpuDeviceStatTracker(Device("cpu"))

stats = tracker.get_stats()

assert "peak_memory_rss_bytes" in stats
assert "peak_memory_vms_bytes" in stats
assert "memory_percent" in stats
assert "num_threads" in stats

assert stats["peak_memory_rss_bytes"] > 0
assert stats["peak_memory_vms_bytes"] > 0
assert 0 <= stats["memory_percent"] <= 100
assert stats["num_threads"] > 0

def test_get_stats_tracks_peak_memory(self) -> None:
tracker = CpuDeviceStatTracker(Device("cpu"))

stats1 = tracker.get_stats()
peak1 = stats1["peak_memory_rss_bytes"]

_ = [0] * 10000

stats2 = tracker.get_stats()
peak2 = stats2["peak_memory_rss_bytes"]

assert peak2 >= peak1

def test_reset_clears_peak_memory(self) -> None:
tracker = CpuDeviceStatTracker(Device("cpu"))

_ = tracker.get_stats()

tracker.reset()

stats = tracker.get_stats()
assert stats["peak_memory_rss_bytes"] >= 0
assert stats["peak_memory_vms_bytes"] >= 0