Skip to content

flush logs in the end no matter what #761

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 53 additions & 21 deletions hyperactor_mesh/src/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::path::Path;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::RwLock;
use std::task::Context as TaskContext;
use std::task::Poll;
use std::time::Duration;
Expand Down Expand Up @@ -592,7 +593,7 @@ impl Actor for LogForwardActor {
Ok(Self {
rx,
logging_client_ref,
stream_to_client: false,
stream_to_client: true,
})
}

Expand Down Expand Up @@ -668,6 +669,30 @@ pub struct LogClientActor {
/// The watch sender for the aggregation window in seconds
aggregate_window_tx: watch::Sender<u64>,
should_aggregate: bool,
// Store aggregators directly in the actor for access in Drop
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
}

impl LogClientActor {
    /// Flush every non-empty aggregator to its output target, then reset it.
    ///
    /// Called both on the periodic aggregation tick and from `Drop`, so it
    /// must never panic: a poisoned lock (a writer panicked while holding it)
    /// is recovered with `into_inner` rather than `unwrap`, because a panic
    /// inside `Drop` while a thread is already unwinding aborts the process.
    fn print_aggregators(aggregators: &RwLock<HashMap<OutputTarget, Aggregator>>) {
        // Best-effort flush: even if the lock was poisoned, the aggregated
        // log lines themselves are still intact and worth emitting.
        let mut aggregators_guard = aggregators
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        for (output_target, aggregator) in aggregators_guard.iter_mut() {
            if aggregator.is_empty() {
                continue;
            }
            match output_target {
                OutputTarget::Stdout => {
                    println!("{}", aggregator);
                }
                OutputTarget::Stderr => {
                    eprintln!("{}", aggregator);
                }
            }

            // Reset the aggregator to start a fresh aggregation window.
            aggregator.reset();
        }
    }
}

#[async_trait]
Expand All @@ -683,26 +708,42 @@ impl Actor for LogClientActor {
let (aggregate_window_tx, aggregate_window_rx) =
watch::channel(DEFAULT_AGGREGATE_WINDOW_SEC);

// Initialize aggregators
let mut aggregators = HashMap::new();
aggregators.insert(OutputTarget::Stderr, Aggregator::new());
aggregators.insert(OutputTarget::Stdout, Aggregator::new());
let aggregators = Arc::new(RwLock::new(aggregators));

// Clone aggregators for the aggregator task
let aggregators_for_task = Arc::clone(&aggregators);

// Start the loggregator
let aggregator_handle =
{ tokio::spawn(async move { start_aggregator(log_rx, aggregate_window_rx).await }) };
let aggregator_handle = tokio::spawn(async move {
start_aggregator(log_rx, aggregate_window_rx, aggregators_for_task).await
});

Ok(Self {
log_tx,
aggregator_handle,
aggregate_window_tx,
should_aggregate: false,
should_aggregate: true,
aggregators,
})
}
}

// Ensure aggregated-but-unflushed log lines are not lost when the client
// actor is torn down (e.g. on a fast process exit).
impl Drop for LogClientActor {
    fn drop(&mut self) {
        // Flush the remaining logs before shutting down.
        // NOTE(review): this path must not panic — a panic inside drop while
        // unwinding aborts the process; print_aggregators should be
        // best-effort.
        Self::print_aggregators(&self.aggregators);
    }
}

async fn start_aggregator(
mut log_rx: mpsc::Receiver<(OutputTarget, String)>,
mut interval_sec_rx: watch::Receiver<u64>,
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
) -> anyhow::Result<()> {
let mut aggregators = HashMap::new();
aggregators.insert(OutputTarget::Stderr, Aggregator::new());
aggregators.insert(OutputTarget::Stdout, Aggregator::new());
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));

Expand All @@ -711,7 +752,8 @@ async fn start_aggregator(
tokio::select! {
// Process incoming log messages
Some((output_target, log_line)) = log_rx.recv() => {
if let Some(aggregator) = aggregators.get_mut(&output_target) {
let mut aggregators_guard = aggregators.write().unwrap();
if let Some(aggregator) = aggregators_guard.get_mut(&output_target) {
if let Err(e) = aggregator.add_line(&log_line) {
tracing::error!("error adding log line: {}", e);
}
Expand All @@ -726,24 +768,14 @@ async fn start_aggregator(

// Every interval tick, print and reset the aggregator
_ = interval.tick() => {
for (output_target, aggregator) in aggregators.iter_mut() {
if aggregator.is_empty() {
continue;
}
if output_target == &OutputTarget::Stdout {
println!("{}", aggregator);
} else {
eprintln!("{}", aggregator);
}

// Reset the aggregator
aggregator.reset();
}
LogClientActor::print_aggregators(&aggregators);
}

// Exit if the channel is closed
else => {
tracing::error!("log channel closed, exiting aggregator");
// Print final aggregated logs before shutting down
LogClientActor::print_aggregators(&aggregators);
break;
}
}
Expand Down
1 change: 0 additions & 1 deletion monarch_extension/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ monarch_types = { version = "0.0.0", path = "../monarch_types" }
nccl-sys = { path = "../nccl-sys", optional = true }
ndslice = { version = "0.0.0", path = "../ndslice" }
pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] }
pyo3-async-runtimes = { version = "0.24", features = ["attributes", "tokio-runtime"] }
rdmaxcel-sys = { path = "../rdmaxcel-sys", optional = true }
serde = { version = "1.0.219", features = ["derive", "rc"] }
tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] }
Expand Down
6 changes: 6 additions & 0 deletions monarch_extension/src/logging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ impl LoggingMeshClient {
}
}

impl Drop for LoggingMeshClient {
    /// Stop the log client actor when the Python-facing handle is dropped so
    /// that any remaining logs are drained before shutdown.
    fn drop(&mut self) {
        // Best-effort shutdown: never `unwrap()` in Drop. A panic here while
        // the thread is already unwinding would abort the whole process, and
        // there is no caller to surface the error to anyway. (The original
        // `let _ = ...unwrap()` both ignored and panicked on the error.)
        let _ = self.client_actor.drain_and_stop();
    }
}

pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
module.add_class::<LoggingMeshClient>()?;
Ok(())
Expand Down
3 changes: 2 additions & 1 deletion python/tests/error_test_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import asyncio
import ctypes
import sys

import click
from monarch._rust_bindings.monarch_extension.blocking import blocking_function
Expand Down
54 changes: 54 additions & 0 deletions python/tests/python_actor_test_binary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import asyncio
import logging
import sys

import click

from monarch.actor import Actor, endpoint, proc_mesh


@click.group()
def main() -> None:
    # Root click command group; subcommands (e.g. `flush-logs`) attach to it.
    pass


class Printer(Actor):
    """Test actor that writes a line to stdout and flushes both std streams."""

    def __init__(self) -> None:
        self.logger: logging.Logger = logging.getLogger()

    @endpoint
    async def print(self, content: str) -> None:
        # Write the line and flush both streams eagerly so the log
        # forwarder observes the output promptly during tests.
        sys.stdout.write(f"{content}\n")
        sys.stdout.flush()
        sys.stderr.flush()


async def _flush_logs() -> None:
    # Drive the scenario for the flush-on-exit test: produce aggregated logs
    # that are only flushed when the process shuts down.
    pm = await proc_mesh(gpus=2)
    # Use a very large aggregation window (1000s) so logs are effectively
    # never flushed during the run — the flush must happen at shutdown.
    await pm.logging_option(aggregate_window_sec=1000)
    am = await pm.spawn("printer", Printer)

    # These should be streamed to client
    for _ in range(5):
        await am.print.call("has print streaming")

    # Sleep a tiny bit so the logs have time to stream back to the client
    await asyncio.sleep(1)


@main.command("flush-logs")
def flush_logs() -> None:
asyncio.run(_flush_logs())


if __name__ == "__main__":
main()
30 changes: 30 additions & 0 deletions python/tests/test_python_actors.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@

# pyre-unsafe
import asyncio
import importlib.resources
import logging
import operator
import os
import re
import subprocess
import sys
import tempfile
import threading
Expand Down Expand Up @@ -713,6 +715,34 @@ async def test_logging_option_defaults() -> None:
pass


# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
@pytest.mark.oss_skip
async def test_flush_logs_fast_exit() -> None:
    """Verify aggregated logs are flushed (exactly once) when the child exits.

    We use a subprocess to run the test so we can capture the flushed logs at
    the end; restoring the original stdout/stderr in-process is error-prone.
    """
    test_bin = importlib.resources.files(str(__package__)).joinpath("test_bin")

    # Run the binary in a separate process and capture stdout and stderr.
    cmd = [str(test_bin), "flush-logs"]
    process = subprocess.run(cmd, capture_output=True, timeout=60, text=True)

    # Check if the process ended without error; include the captured output in
    # the failure so CI logs are actionable (the bare return code was not).
    if process.returncode != 0:
        raise RuntimeError(
            f"{cmd} ended with error code {process.returncode}.\n"
            f"stdout:\n{process.stdout}\n"
            f"stderr:\n{process.stderr}"
        )

    # The aggregator should collapse the repeated prints into a single
    # "... similar ... has print streaming" summary line emitted at flush.
    assert (
        len(
            re.findall(
                r"similar.*has print streaming",
                process.stdout,
            )
        )
        == 1
    ), process.stdout


class SendAlot(Actor):
@endpoint
async def send(self, port: Port[int]):
Expand Down