Skip to content

Commit 4cb03a5

Browse files
James Sun and facebook-github-bot
authored and committed
flush logs at the end no matter what (#761)
Summary: Make sure we flush any remaining logs when the program ends. However, some logs may still be missed, because logs sent through the channel are not guaranteed to be streamed back before the program ends. Differential Revision: D79620358
1 parent efbcb45 commit 4cb03a5

File tree

4 files changed

+134
-20
lines changed

4 files changed

+134
-20
lines changed

hyperactor_mesh/src/logging.rs

Lines changed: 52 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use std::path::Path;
1212
use std::path::PathBuf;
1313
use std::pin::Pin;
1414
use std::sync::Arc;
15+
use std::sync::RwLock;
1516
use std::task::Context as TaskContext;
1617
use std::task::Poll;
1718
use std::time::Duration;
@@ -668,6 +669,30 @@ pub struct LogClientActor {
668669
/// The watch sender for the aggregation window in seconds
669670
aggregate_window_tx: watch::Sender<u64>,
670671
should_aggregate: bool,
672+
// Store aggregators directly in the actor for access in Drop
673+
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
674+
}
675+
676+
impl LogClientActor {
677+
fn print_aggregators(aggregators: &RwLock<HashMap<OutputTarget, Aggregator>>) {
678+
let mut aggregators_guard = aggregators.write().unwrap();
679+
for (output_target, aggregator) in aggregators_guard.iter_mut() {
680+
if aggregator.is_empty() {
681+
continue;
682+
}
683+
match output_target {
684+
OutputTarget::Stdout => {
685+
println!("{}", aggregator);
686+
}
687+
OutputTarget::Stderr => {
688+
eprintln!("{}", aggregator);
689+
}
690+
}
691+
692+
// Reset the aggregator
693+
aggregator.reset();
694+
}
695+
}
671696
}
672697

673698
#[async_trait]
@@ -683,26 +708,42 @@ impl Actor for LogClientActor {
683708
let (aggregate_window_tx, aggregate_window_rx) =
684709
watch::channel(DEFAULT_AGGREGATE_WINDOW_SEC);
685710

711+
// Initialize aggregators
712+
let mut aggregators = HashMap::new();
713+
aggregators.insert(OutputTarget::Stderr, Aggregator::new());
714+
aggregators.insert(OutputTarget::Stdout, Aggregator::new());
715+
let aggregators = Arc::new(RwLock::new(aggregators));
716+
717+
// Clone aggregators for the aggregator task
718+
let aggregators_for_task = Arc::clone(&aggregators);
719+
686720
// Start the loggregator
687-
let aggregator_handle =
688-
{ tokio::spawn(async move { start_aggregator(log_rx, aggregate_window_rx).await }) };
721+
let aggregator_handle = tokio::spawn(async move {
722+
start_aggregator(log_rx, aggregate_window_rx, aggregators_for_task).await
723+
});
689724

690725
Ok(Self {
691726
log_tx,
692727
aggregator_handle,
693728
aggregate_window_tx,
694-
should_aggregate: false,
729+
should_aggregate: true,
730+
aggregators,
695731
})
696732
}
697733
}
698734

735+
impl Drop for LogClientActor {
736+
fn drop(&mut self) {
737+
// Flush the remaining logs before shutting down
738+
Self::print_aggregators(&self.aggregators);
739+
}
740+
}
741+
699742
async fn start_aggregator(
700743
mut log_rx: mpsc::Receiver<(OutputTarget, String)>,
701744
mut interval_sec_rx: watch::Receiver<u64>,
745+
aggregators: Arc<RwLock<HashMap<OutputTarget, Aggregator>>>,
702746
) -> anyhow::Result<()> {
703-
let mut aggregators = HashMap::new();
704-
aggregators.insert(OutputTarget::Stderr, Aggregator::new());
705-
aggregators.insert(OutputTarget::Stdout, Aggregator::new());
706747
let mut interval =
707748
tokio::time::interval(tokio::time::Duration::from_secs(*interval_sec_rx.borrow()));
708749

@@ -711,7 +752,8 @@ async fn start_aggregator(
711752
tokio::select! {
712753
// Process incoming log messages
713754
Some((output_target, log_line)) = log_rx.recv() => {
714-
if let Some(aggregator) = aggregators.get_mut(&output_target) {
755+
let mut aggregators_guard = aggregators.write().unwrap();
756+
if let Some(aggregator) = aggregators_guard.get_mut(&output_target) {
715757
if let Err(e) = aggregator.add_line(&log_line) {
716758
tracing::error!("error adding log line: {}", e);
717759
}
@@ -726,24 +768,14 @@ async fn start_aggregator(
726768

727769
// Every interval tick, print and reset the aggregator
728770
_ = interval.tick() => {
729-
for (output_target, aggregator) in aggregators.iter_mut() {
730-
if aggregator.is_empty() {
731-
continue;
732-
}
733-
if output_target == &OutputTarget::Stdout {
734-
println!("{}", aggregator);
735-
} else {
736-
eprintln!("{}", aggregator);
737-
}
738-
739-
// Reset the aggregator
740-
aggregator.reset();
741-
}
771+
LogClientActor::print_aggregators(&aggregators);
742772
}
743773

744774
// Exit if the channel is closed
745775
else => {
746776
tracing::error!("log channel closed, exiting aggregator");
777+
// Print final aggregated logs before shutting down
778+
LogClientActor::print_aggregators(&aggregators);
747779
break;
748780
}
749781
}

monarch_extension/src/logging.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ impl LoggingMeshClient {
9292
}
9393
}
9494

95+
impl Drop for LoggingMeshClient {
96+
fn drop(&mut self) {
97+
let _ = self.client_actor.drain_and_stop().unwrap();
98+
}
99+
}
100+
95101
pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
96102
module.add_class::<LoggingMeshClient>()?;
97103
Ok(())

python/tests/actor_log_flush_test.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import asyncio
11+
import logging
12+
import sys
13+
14+
from monarch.actor import Actor, endpoint, proc_mesh
15+
16+
17+
class Printer(Actor):
18+
def __init__(self) -> None:
19+
self.logger: logging.Logger = logging.getLogger()
20+
21+
@endpoint
22+
async def print(self, content: str) -> None:
23+
print(f"{content}", flush=True)
24+
sys.stdout.flush()
25+
sys.stderr.flush()
26+
27+
28+
async def main() -> None:
29+
pm = await proc_mesh(gpus=2)
30+
# never flush
31+
await pm.logging_option(aggregate_window_sec=1000)
32+
am = await pm.spawn("printer", Printer)
33+
34+
# These should be streamed to client
35+
for _ in range(5):
36+
await am.print.call("has print streaming")
37+
38+
# Sleep a tiny so we allow the logs to stream back to the client
39+
await asyncio.sleep(1)
40+
41+
42+
if __name__ == "__main__":
43+
asyncio.run(main())

python/tests/test_python_actors.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@
66

77
# pyre-unsafe
88
import asyncio
9+
import importlib
910
import logging
1011
import operator
1112
import os
1213
import re
14+
import subprocess
1315
import sys
1416
import tempfile
1517
import threading
@@ -749,6 +751,37 @@ async def test_logging_option_defaults() -> None:
749751
pass
750752

751753

754+
# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
755+
@pytest.mark.oss_skip
756+
async def test_flush_logs_fast_exit() -> None:
757+
# We use a subprocess to run the test so we can handle the flushed logs at the end.
758+
# Otherwise, it is hard to restore the original stdout/stderr.
759+
with (
760+
importlib.resources.as_file(
761+
importlib.resources.files(__package__).joinpath("actor_log_flush_test")
762+
) as test_log_main,
763+
):
764+
if not test_log_main.exists():
765+
raise ImportError("cannot find file binary test_log_main")
766+
767+
# Run the binary in a separate process and capture stdout and stderr
768+
process = subprocess.Popen(
769+
[test_log_main],
770+
stdout=subprocess.PIPE,
771+
text=True,
772+
)
773+
774+
# Wait for the process to complete and get output
775+
stdout, _ = process.communicate(timeout=60)
776+
777+
# Check if the process ended without error
778+
if process.returncode != 0:
779+
raise RuntimeError(f"Process ended with error code {process.returncode}. ")
780+
781+
# Assertions on the captured output
782+
assert "has print streaming" in stdout, stdout
783+
784+
752785
class SendAlot(Actor):
753786
@endpoint
754787
async def send(self, port: Port[int]):

0 commit comments

Comments
 (0)