diff --git a/monarch_conda/Cargo.toml b/monarch_conda/Cargo.toml
new file mode 100644
index 000000000..3e536ce3c
--- /dev/null
+++ b/monarch_conda/Cargo.toml
@@ -0,0 +1,38 @@
+# @generated by autocargo from //monarch/monarch_conda:[conda-sync-cli,monarch_conda]
+
+[package]
+name = "monarch_conda"
+version = "0.0.0"
+authors = ["Meta"]
+edition = "2021"
+license = "BSD-3-Clause"
+
+[[bin]]
+name = "conda_sync_cli"
+path = "src/main.rs"
+
+[dependencies]
+anyhow = "1.0.98"
+async-tempfile = "0.7.0"
+bincode = "1.3.3"
+chrono = { version = "0.4.41", features = ["clock", "serde", "std"], default-features = false }
+clap = { version = "4.5.41", features = ["derive", "env", "string", "unicode", "wrap_help"] }
+dashmap = { version = "5.5.3", features = ["rayon", "serde"] }
+digest = "0.10"
+filetime = "0.2.25"
+futures = { version = "0.3.31", features = ["async-await", "compat"] }
+globset = { version = "0.4.13", features = ["serde1"] }
+ignore = "0.4"
+itertools = "0.14.0"
+memchr = "2.7.5"
+memmap2 = "0.9.5"
+rattler_conda_types = "0.28.3"
+serde = { version = "1.0.219", features = ["derive", "rc"] }
+serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] }
+sha2 = "0.10.6"
+tokio = { version = "1.46.1", features = ["full", "test-util", "tracing"] }
+tokio-util = { version = "0.7.15", features = ["full"] }
+walkdir = "2.3"
+
+[dev-dependencies]
+tempfile = "3.15"
diff --git a/monarch_conda/src/diff.rs b/monarch_conda/src/diff.rs
new file mode 100644
index 000000000..1b969fbae
--- /dev/null
+++ b/monarch_conda/src/diff.rs
@@ -0,0 +1,158 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+use std::collections::HashMap;
+use std::path::Path;
+use std::path::PathBuf;
+use std::time::Duration;
+use std::time::SystemTime;
+use std::time::UNIX_EPOCH;
+
+use anyhow::Context;
+use anyhow::Result;
+use anyhow::ensure;
+use chrono::DateTime;
+use chrono::Utc;
+use digest::Digest;
+use digest::Output;
+use rattler_conda_types::PrefixRecord;
+use rattler_conda_types::prefix_record::PathsEntry;
+use serde::Deserialize;
+use serde::Serialize;
+use serde_json;
+use sha2::Sha256;
+use tokio::fs;
+use walkdir::WalkDir;
+
+use crate::hash_utils;
+use crate::pack_meta::History;
+use crate::pack_meta::Offsets;
+
+/// Fingerprint of the conda-meta directory, used by `CondaFingerprint` below.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct CondaMetaFingerprint {
+    // TODO(agallagher): It might be worth storing more information about installed
+    // packages, so that we could print better error messages when we detect two
+    // envs are not equivalent.
+    hash: Output<Sha256>,
+}
+
+impl CondaMetaFingerprint {
+    async fn from_env(path: &Path) -> Result<Self> {
+        let mut hasher = Sha256::new();
+        hash_utils::hash_directory_tree(&path.join("conda-meta"), &mut hasher).await?;
+        Ok(Self {
+            hash: hasher.finalize(),
+        })
+    }
+}
+
+/// Fingerprint of the pack-meta directory, used by `CondaFingerprint` below.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct PackMetaFingerprint {
+    offsets: Output<Sha256>,
+    pub history: History,
+}
+
+impl PackMetaFingerprint {
+    async fn from_env(path: &Path) -> Result<Self> {
+        let pack_meta = path.join("pack-meta");
+
+        // Read the full history.jsonl file.
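+        // NOTE: the parsed `History` is carried in the fingerprint (rather than hashed)
+        // so the receiver can later query prefixes and prefix-update windows from it
+        // when building the mtime comparator.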
+        let contents = fs::read_to_string(pack_meta.join("history.jsonl")).await?;
+        let history = History::from_contents(&contents)?;
+
+        // Read the entire offsets.jsonl file, but avoid hashing the offsets, which can change.
+        let mut hasher = Sha256::new();
+        let contents = fs::read_to_string(pack_meta.join("offsets.jsonl")).await?;
+        let offsets = Offsets::from_contents(&contents)?;
+        for ent in offsets.entries {
+            let contents = bincode::serialize(&(ent.path, ent.mode, ent.offsets.len()))?;
+            hasher.update(contents.len().to_le_bytes());
+            hasher.update(&contents);
+        }
+        let offsets = hasher.finalize();
+
+        Ok(Self { history, offsets })
+    }
+}
+
+/// A fingerprint of a conda environment, used to detect if two envs are similar enough to
+/// facilitate mtime-based conda syncing.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
+pub struct CondaFingerprint {
+    pub conda_meta: CondaMetaFingerprint,
+    pub pack_meta: PackMetaFingerprint,
+}
+
+impl CondaFingerprint {
+    pub async fn from_env(path: &Path) -> Result<Self> {
+        Ok(Self {
+            conda_meta: CondaMetaFingerprint::from_env(path).await?,
+            pack_meta: PackMetaFingerprint::from_env(path).await?,
+        })
+    }
+
+    /// Create a comparator to compare the mtimes of files from two "equivalent" conda envs.
+    /// In particular, this comparator will be aware of spurious mtime changes that occur from
+    /// prefix replacement (via `meta-pack`), and will filter them out.
+    pub fn mtime_comparator(
+        a: &Self,
+        b: &Self,
+    ) -> Result<Box<dyn Fn(&SystemTime, &SystemTime) -> std::cmp::Ordering + Send + Sync>> {
+        let (a_prefix, a_base) = a.pack_meta.history.first()?;
+        let (b_prefix, b_base) = b.pack_meta.history.first()?;
+        ensure!(a_prefix == b_prefix);
+
+        // NOTE(agallagher): There appears to be some mtime drift on some files after fbpkg creation,
+        // so account for that here.
+        let slop = Duration::from_secs(5 * 60);
+
+        // We load the timestamp from the first history entry, and use this to see if any
+        // files have been updated since the env was created.
+        let a_base = UNIX_EPOCH + Duration::from_secs(a_base) + slop;
+        let b_base = UNIX_EPOCH + Duration::from_secs(b_base) + slop;
+
+        // We also load the last prefix-update window for each, as any mtimes from this window
+        // should be ignored.
+        let a_window = a
+            .pack_meta
+            .history
+            .prefix_and_last_update_window()?
+            .1
+            .map(|(s, e)| {
+                (
+                    UNIX_EPOCH + Duration::from_secs(s),
+                    UNIX_EPOCH + Duration::from_secs(e + 1),
+                )
+            });
+        let b_window = b
+            .pack_meta
+            .history
+            .prefix_and_last_update_window()?
+            .1
+            .map(|(s, e)| {
+                (
+                    UNIX_EPOCH + Duration::from_secs(s),
+                    UNIX_EPOCH + Duration::from_secs(e + 1),
+                )
+            });
+
+        Ok(Box::new(move |a: &SystemTime, b: &SystemTime| {
+            match (
+                *a > a_base && a_window.is_none_or(|(s, e)| *a < s || *a > e),
+                *b > b_base && b_window.is_none_or(|(s, e)| *b < s || *b > e),
+            ) {
+                (true, false) => std::cmp::Ordering::Greater,
+                (false, true) => std::cmp::Ordering::Less,
+                (false, false) => std::cmp::Ordering::Equal,
+                (true, true) => a.cmp(b),
+            }
+        }))
+    }
+}
diff --git a/monarch_conda/src/hash_utils.rs b/monarch_conda/src/hash_utils.rs
new file mode 100644
index 000000000..f0b74bed3
--- /dev/null
+++ b/monarch_conda/src/hash_utils.rs
@@ -0,0 +1,162 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+use std::path::Path;
+
+use anyhow::Result;
+use anyhow::bail;
+use digest::Digest;
+use tokio::fs;
+use walkdir::WalkDir;
+
+/// Compute a hash of a directory tree using the provided hasher.
+///
+/// This function traverses the directory tree deterministically (sorted by file name)
+/// and includes both file paths and file contents in the hash computation.
+///
+/// # Arguments
+/// * `dir` - The directory to hash
+/// * `hasher` - A hasher implementing the Digest trait (e.g., Sha256::new())
+///
+/// # Returns
+/// () - The hasher is updated with the directory tree data
+pub async fn hash_directory_tree<D: Digest>(dir: &Path, hasher: &mut D) -> Result<()> {
+    // Iterate entries with deterministic ordering
+    for entry in WalkDir::new(dir).sort_by_file_name().into_iter() {
+        let entry = entry?;
+        let path = entry.path();
+        let relative_path = path.strip_prefix(dir)?;
+
+        // Hash the relative path (normalized to use forward slashes)
+        let path_str = relative_path.to_string_lossy().replace('\\', "/");
+        hasher.update(path_str.as_bytes());
+        hasher.update(b"\0"); // null separator
+
+        if entry.file_type().is_file() {
+            // Hash file type marker, size, and contents
+            hasher.update(b"FILE:");
+            let contents = fs::read(path).await?;
+            hasher.update(contents.len().to_le_bytes());
+            hasher.update(&contents);
+        } else if entry.file_type().is_dir() {
+            // For directories, hash a type marker
+            hasher.update(b"DIR:");
+        } else if entry.file_type().is_symlink() {
+            // For symlinks, hash type marker, target size, and target
+            hasher.update(b"SYMLINK:");
+            let target = fs::read_link(path).await?;
+            let target_string = target.to_string_lossy().into_owned();
+            let target_bytes = target_string.as_bytes();
+            hasher.update(target_bytes.len().to_le_bytes());
+            hasher.update(target_bytes);
+        } else {
+            // Unexpected file type
+            bail!("Unexpected file type for path: {}", path.display());
+        }
+
+        hasher.update(b"\n"); // entry separator
+    }
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use sha2::Sha256;
+    use tempfile::TempDir;
+    use tokio::fs;
+
+    use super::*;
+
+    #[tokio::test]
+    async fn test_hash_directory_tree() -> Result<()> {
+        // Create a temporary directory with some test files
+        let temp_dir = TempDir::new()?;
+        let dir_path = temp_dir.path();
+
+        // Create test files
+        fs::write(dir_path.join("file1.txt"), "Hello, world!").await?;
+        fs::write(dir_path.join("file2.txt"), "Another file").await?;
+        fs::create_dir(dir_path.join("subdir")).await?;
+        fs::write(dir_path.join("subdir").join("file3.txt"), "Nested file").await?;
+
+        // Hash the directory
+        let mut hasher1 = Sha256::new();
+        let mut hasher2 = Sha256::new();
+        hash_directory_tree(dir_path, &mut hasher1).await?;
+        hash_directory_tree(dir_path, &mut hasher2).await?;
+
+        let hash1 = hasher1.finalize();
+        let hash2 = hasher2.finalize();
+
+        // Should be deterministic
+        assert_eq!(hash1, hash2);
+        assert_eq!(hash1.len(), 32); // SHA256 raw bytes length
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_no_hash_collision_between_file_and_dir() -> Result<()> {
+        // Test that a file containing "DIR:" and an empty directory don't collide
+        let temp_dir1 = TempDir::new()?;
+        let temp_dir2 = TempDir::new()?;
+
+        // Create a file with content that could collide with directory marker
+        fs::write(temp_dir1.path().join("test"), "DIR:").await?;
+
+        // Create an empty directory with the same name
+        fs::create_dir(temp_dir2.path().join("test")).await?;
+
+        // Hash both scenarios
+        let mut hasher_file = Sha256::new();
+        let mut hasher_dir = Sha256::new();
+        hash_directory_tree(temp_dir1.path(), &mut hasher_file).await?;
+        hash_directory_tree(temp_dir2.path(), &mut hasher_dir).await?;
+
+        let hash_file = hasher_file.finalize();
+        let hash_dir = hasher_dir.finalize();
+
+        // Should be different due to type prefixes
+        assert_ne!(hash_file, hash_dir);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_no_structural_marker_collision() -> Result<()> {
+        // Test that files containing our structural markers don't cause collisions
+        let temp_dir1 = TempDir::new()?;
+        let temp_dir2 = TempDir::new()?;
+
+        // Create a file that could potentially collide without size prefixes:
+        // Path: "test1", Content: "foo\n"
+        // Without size prefixes: test1\0FILE:foo\n\n
+        fs::write(temp_dir1.path().join("test1"), "foo\n").await?;
+
+        // Create a file with a path that includes our structural markers:
+        // Path: "test1\nFILE:", Content: "foo\n"
+        // Without size prefixes: test1\nFILE:\0FILE:foo\n\n
+        // This could potentially collide with the above
+        fs::write(temp_dir2.path().join("test1\nFILE:"), "foo\n").await?;
+
+        // Hash both scenarios
+        let mut hasher1 = Sha256::new();
+        let mut hasher2 = Sha256::new();
+        hash_directory_tree(temp_dir1.path(), &mut hasher1).await?;
+        hash_directory_tree(temp_dir2.path(), &mut hasher2).await?;
+
+        let hash1 = hasher1.finalize();
+        let hash2 = hasher2.finalize();
+
+        // Should be different - size prefixes prevent structural marker confusion
+        assert_ne!(hash1, hash2);
+
+        Ok(())
+    }
+}
diff --git a/monarch_conda/src/lib.rs b/monarch_conda/src/lib.rs
new file mode 100644
index 000000000..6aa170ec2
--- /dev/null
+++ b/monarch_conda/src/lib.rs
@@ -0,0 +1,14 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#![feature(once_cell_try)]
+
+pub mod diff;
+pub mod hash_utils;
+pub mod pack_meta;
+pub mod sync;
diff --git a/monarch_conda/src/main.rs b/monarch_conda/src/main.rs
new file mode 100644
index 000000000..404af6a81
--- /dev/null
+++ b/monarch_conda/src/main.rs
@@ -0,0 +1,43 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#![feature(exit_status_error)]
+
+use std::path::PathBuf;
+
+use anyhow::Result;
+use clap::Parser;
+use futures::try_join;
+use monarch_conda::sync::receiver;
+use monarch_conda::sync::sender;
+
+#[derive(Parser)]
+#[command(name = "conda-sync")]
+#[command(about = "A tool to sync conda environments")]
+struct Args {
+    /// Path to the source conda environment
+    src: PathBuf,
+    /// Path to the dest conda environment
+    dst: PathBuf,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    let args = Args::parse();
+
+    // Receiver -> Sender
+    let (recv, send) = tokio::io::duplex(5 * 1024 * 1024);
+    let (from_receiver, to_receiver) = tokio::io::split(recv);
+    let (from_sender, to_sender) = tokio::io::split(send);
+    try_join!(
+        receiver(&args.dst, from_sender, to_sender),
+        sender(&args.src, from_receiver, to_receiver),
+    )?;
+
+    Ok(())
+}
diff --git a/monarch_conda/src/sync.rs b/monarch_conda/src/sync.rs
new file mode 100644
index 000000000..6fa2d9e4a
--- /dev/null
+++ b/monarch_conda/src/sync.rs
@@ -0,0 +1,1421 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+use std::cell::OnceCell;
+use std::cmp::Ordering;
+use std::collections::BTreeMap;
+use std::collections::BTreeSet;
+use std::collections::HashMap;
+use std::io::ErrorKind;
+use std::os::unix::fs::MetadataExt;
+use std::os::unix::fs::PermissionsExt;
+use std::path::Path;
+use std::path::PathBuf;
+use std::sync::Arc;
+use std::sync::mpsc::channel;
+use std::time::Duration;
+use std::time::SystemTime;
+use std::time::UNIX_EPOCH;
+
+use anyhow::Context;
+use anyhow::Result;
+use anyhow::bail;
+use anyhow::ensure;
+use async_tempfile::TempFile;
+use dashmap::DashMap;
+use dashmap::mapref::entry::Entry;
+use filetime::FileTime;
+use futures::SinkExt;
+use futures::StreamExt;
+use futures::try_join;
+use globset::Glob;
+use globset::GlobSet;
+use globset::GlobSetBuilder;
+use ignore::DirEntry;
+use ignore::WalkBuilder;
+use ignore::WalkState;
+use itertools::Itertools;
+use memchr::memmem::Finder;
+use memmap2::MmapMut;
+use serde::Deserialize;
+use serde::Serialize;
+use tokio::fs;
+use tokio::io::AsyncRead;
+use tokio::io::AsyncReadExt;
+use tokio::io::AsyncWrite;
+use tokio::io::AsyncWriteExt;
+use tokio_util::codec::FramedRead;
+use tokio_util::codec::FramedWrite;
+use tokio_util::codec::LengthDelimitedCodec;
+
+use crate::diff::CondaFingerprint;
+
+#[derive(Eq, PartialEq)]
+enum Origin {
+    Src,
+    Dst,
+}
+
+#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
+enum FileTypeInfo {
+    Directory,
+    File(bool),
+    Symlink,
+}
+
+impl FileTypeInfo {
+    fn same(&self, other: &FileTypeInfo) -> bool {
+        match (self, other) {
+            (FileTypeInfo::Directory, FileTypeInfo::Directory) => true,
+            (FileTypeInfo::File(_), FileTypeInfo::File(_)) => true,
+            (FileTypeInfo::Symlink, FileTypeInfo::Symlink) => true,
+            _ => false,
+        }
+    }
+}
+
+#[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
+struct Metadata {
+    mtime: SystemTime,
+    ftype: FileTypeInfo,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub enum Receive {
+    File(bool),
+    Symlink,
+}
+
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
+pub enum Action {
+    Delete(bool),
+    Directory,
+    Receive(SystemTime, Receive),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct FileSectionHeader {
+    num: usize,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct FileHeader {
+    path: PathBuf,
+    symlink: bool,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+enum FileContents {
+    Symlink(PathBuf),
+    File(u64),
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct FileContentsHeader {
+    path: PathBuf,
+    contents: FileContents,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+enum FileList {
+    Entry(PathBuf, Metadata),
+    Done,
+}
+
+struct ActionsBuilder {
+    ignores: Option<GlobSet>,
+    state: DashMap<PathBuf, (Origin, Metadata)>,
+    actions: DashMap<PathBuf, Action>,
+    mtime_comparator: Box<dyn Fn(&SystemTime, &SystemTime) -> Ordering + Send + Sync + 'static>,
+}
+
+impl ActionsBuilder {
+    fn new_with(
+        ignores: Option<GlobSet>,
+        mtime_comparator: Box<dyn Fn(&SystemTime, &SystemTime) -> Ordering + Send + Sync + 'static>,
+    ) -> Self {
+        Self {
+            ignores,
+            state: DashMap::new(),
+            actions: DashMap::new(),
+            mtime_comparator,
+        }
+    }
+
+    fn process(&self, origin: Origin, path: PathBuf, metadata: Metadata) -> Result<()> {
+        match self.state.entry(path) {
+            Entry::Occupied(val) => {
+                let (path, (existing_origin, existing_metadata)) = val.remove_entry();
+                if let Some(ignores) = &self.ignores {
+                    if ignores.is_match(path.as_path()) {
+                        return Ok(());
+                    }
+                }
+                ensure!(existing_origin != origin);
+                let (src, dst) = match origin {
+                    Origin::Dst => (existing_metadata, metadata),
+                    Origin::Src => (metadata, existing_metadata),
+                };
+                if src.ftype == FileTypeInfo::Directory && dst.ftype == FileTypeInfo::Directory {
+                    // --omit-dir-times
+                } else {
+                    match (self.mtime_comparator)(&src.mtime, &dst.mtime) {
+                        Ordering::Less => bail!("{}: dst newer than src", path.display()),
+                        Ordering::Equal => {
+                            ensure!(
+                                src.ftype.same(&dst.ftype),
+                                "{}: {:?} != {:?}",
+                                path.display(),
+                                dst,
+                                src
+                            );
+                        }
+                        Ordering::Greater => {
+                            self.actions.insert(
+                                path,
+                                match src.ftype {
+                                    FileTypeInfo::File(executable) => {
+                                        Action::Receive(src.mtime, Receive::File(executable))
+                                    }
+                                    FileTypeInfo::Symlink => {
+                                        Action::Receive(src.mtime, Receive::Symlink)
+                                    }
+                                    FileTypeInfo::Directory => Action::Directory,
+                                },
+                            );
+                        }
+                    }
+                }
+            }
+            Entry::Vacant(entry) => {
+                entry.insert((origin, metadata));
+            }
+        }
+        Ok(())
+    }
+
+    fn process_src(&self, path: PathBuf, metadata: Metadata) -> Result<()> {
+        self.process(Origin::Src, path, metadata)
+    }
+
+    fn process_dst(&self, path: PathBuf, metadata: Metadata) -> Result<()> {
+        self.process(Origin::Dst, path, metadata)
+    }
+
+    fn into_actions(self) -> HashMap<PathBuf, Action> {
+        let mut actions: HashMap<_, _> = self.actions.into_iter().collect();
+        for (path, (origin, metadata)) in self.state.into_iter() {
+            match origin {
+                Origin::Src => {
+                    if let Some(ignores) = &self.ignores {
+                        if ignores.is_match(path.as_path()) {
+                            continue;
+                        }
+                    }
+                    actions.insert(
+                        path,
+                        match metadata.ftype {
+                            FileTypeInfo::File(executable) => {
+                                Action::Receive(metadata.mtime, Receive::File(executable))
+                            }
+                            FileTypeInfo::Directory => Action::Directory,
+                            FileTypeInfo::Symlink => {
+                                Action::Receive(metadata.mtime, Receive::Symlink)
+                            }
+                        },
+                    );
+                }
+                Origin::Dst => {
+                    actions.insert(
+                        path,
+                        Action::Delete(matches!(metadata.ftype, FileTypeInfo::Directory)),
+                    );
+                }
+            }
+        }
+        actions
+    }
+}
+
+fn walk_dir<
+    E: Into<anyhow::Error>,
+    F: Fn(PathBuf, Metadata) -> Result<(), E> + Sync + Send + 'static,
+>(
+    src: PathBuf,
+    callback: F,
+) -> Result<()> {
+    let (error_tx, error_rx) = channel();
+
+    let src_handle = src.clone();
+    let handle_ent = move |entry: DirEntry| -> Result<()> {
+        let metadata = entry.metadata()?;
+        callback(
+            entry
+                .path()
+                .strip_prefix(src_handle.clone())
+                .context("sub path")?
+                .to_path_buf(),
+            Metadata {
+                mtime: UNIX_EPOCH
+                    + Duration::new(
+                        metadata.mtime().try_into()?,
+                        metadata.mtime_nsec().try_into()?,
+                    ),
+                ftype: if metadata.file_type().is_file() {
+                    let mode = metadata.permissions().mode();
+                    FileTypeInfo::File(mode & 0o100 != 0)
+                } else if metadata.file_type().is_dir() {
+                    FileTypeInfo::Directory
+                } else if metadata.file_type().is_symlink() {
+                    FileTypeInfo::Symlink
+                } else {
+                    bail!("unexpected file type")
+                },
+            },
+        )
+        .map_err(Into::into)?;
+        Ok(())
+    };
+
+    WalkBuilder::new(src)
+        .standard_filters(true)
+        .same_file_system(true)
+        .build_parallel()
+        .run(|| {
+            Box::new(|ent| match ent.map_err(Into::into).and_then(&handle_ent) {
+                Ok(()) => WalkState::Continue,
+                Err(err) => {
+                    error_tx.clone().send(err).unwrap();
+                    WalkState::Quit
+                }
+            })
+        });
+
+    match error_rx.try_recv() {
+        Ok(err) => Err(err),
+        _ => Ok(()),
+    }
+}
+
+pub async fn sender(
+    src: &Path,
+    from_receiver: impl AsyncRead + Unpin,
+    to_receiver: impl AsyncWrite + Unpin,
+) -> Result<()> {
+    let mut to_receiver = FramedWrite::new(to_receiver, LengthDelimitedCodec::new());
+    let mut from_receiver = FramedRead::new(from_receiver, LengthDelimitedCodec::new());
+
+    let (ent_tx, mut ent_rx) = tokio::sync::mpsc::unbounded_channel();
+    let src_clone = src.to_path_buf();
+    try_join!(
+        async {
+            tokio::task::spawn_blocking(move || {
+                walk_dir(src_clone.clone(), move |path, ent| ent_tx.send((path, ent)))
+            })
+            .await?
+        },
+        async {
+            // Send conda env fingerprint
+            let src_env = CondaFingerprint::from_env(src).await?;
+            to_receiver
+                .send(bincode::serialize(&src_env)?.into())
+                .await
+                .context("sending src conda fingerprint")?;
+            to_receiver.flush().await?;
+
+            // Send file lists.
+            while let Some((path, metadata)) = ent_rx.recv().await {
+                to_receiver
+                    .send(bincode::serialize(&FileList::Entry(path, metadata))?.into())
+                    .await
+                    .context("sending file ent")?;
+            }
+            to_receiver
+                .send(bincode::serialize(&FileList::Done)?.into())
+                .await
+                .context("sending file list end")?;
+            to_receiver.flush().await?;
+
+            anyhow::Ok(())
+        },
+    )?;
+
+    // Convert back to a raw stream to send file headers + contents.
+    to_receiver.flush().await?;
+    let mut to_receiver = to_receiver.into_inner();
+
+    let hdr: FileSectionHeader =
+        bincode::deserialize(&from_receiver.next().await.context("header")??)?;
+    for _ in 0..hdr.num {
+        let FileHeader { path, symlink } =
+            bincode::deserialize(&from_receiver.next().await.context("signature")??)?;
+        let fpath = src.join(&path);
+        if symlink {
+            let header = FileContentsHeader {
+                path,
+                contents: FileContents::Symlink(fs::read_link(&fpath).await?),
+            };
+            let header = bincode::serialize(&header)?;
+            to_receiver.write_all(&header.len().to_le_bytes()).await?;
+            to_receiver
+                .write_all(&header)
+                .await
+                .context("sending sig header")?;
+        } else {
+            let mut base = fs::File::open(src.join(&path)).await?;
+            let header = FileContentsHeader {
+                path,
+                contents: FileContents::File(base.metadata().await?.len()),
+            };
+            let header = bincode::serialize(&header)?;
+            to_receiver.write_all(&header.len().to_le_bytes()).await?;
+            to_receiver
+                .write_all(&header)
+                .await
+                .context("sending sig header")?;
+            tokio::io::copy(&mut base, &mut to_receiver).await?;
+        }
+    }
+    to_receiver.flush().await?;
+
+    Ok(())
+}
+
+async fn persist(tmp: TempFile, path: &Path) -> Result<(), std::io::Error> {
+    // Atomically rename the temp file into its final location.
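+    // `rename` fails with `IsADirectory` when the target is an existing directory
+    // (e.g. after a file/directory type change); in that case, remove the
+    // now-empty directory and retry.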
+    match fs::rename(tmp.file_path(), &path).await {
+        Err(err) if err.kind() == ErrorKind::IsADirectory => {
+            async {
+                fs::remove_dir(&path).await?;
+                fs::rename(tmp.file_path(), &path).await
+            }
+            .await
+        }
+        other => other,
+    }?;
+    tmp.drop_async().await;
+    Ok(())
+}
+
+/// Helper function to set the mtime on a single file, symlink, or directory.
+async fn set_mtime(path: &Path, mtime: SystemTime) -> Result<(), std::io::Error> {
+    let mtime = FileTime::from_system_time(mtime);
+    filetime::set_symlink_file_times(path, mtime.clone(), mtime)?;
+    Ok(())
+}
+
+async fn make_executable(path: &Path) -> Result<(), std::io::Error> {
+    let metadata = fs::metadata(path).await?;
+    let mut permissions = metadata.permissions();
+    let mode = permissions.mode();
+    permissions.set_mode(mode | 0o111);
+    fs::set_permissions(path, permissions).await?;
+    Ok(())
+}
+
+fn null_pad(input: &[u8], len: usize) -> Result<Vec<u8>> {
+    ensure!(input.len() <= len, "Input is longer than target length");
+    let mut padded = Vec::with_capacity(len);
+    padded.extend_from_slice(input);
+    padded.resize(len, 0);
+    Ok(padded)
+}
+
+fn replace_bytestring(vec: &mut Vec<u8>, from: &[u8], to: &[u8]) {
+    if vec.len() >= from.len() {
+        let mut i = 0;
+        while i <= vec.len() - from.len() {
+            if &vec[i..i + from.len()] == from {
+                vec.splice(i..i + from.len(), to.iter().cloned());
+                i += to.len(); // Skip past the inserted section
+            } else {
+                i += 1;
+            }
+        }
+    }
+}
+
+fn is_binary(buf: &[u8]) -> bool {
+    // If any null byte is seen, treat as binary
+    if buf.iter().contains(&0) {
+        return true;
+    }
+
+    // Count non-printable characters (excluding common control chars)
+    let non_print = buf
+        .iter()
+        .filter(|&&b| !(b == b'\n' || b == b'\r' || b == b'\t' || (0x20..=0x7E).contains(&b)))
+        .count();
+
+    // If more than 30%, consider binary
+    non_print * 100 > buf.len() * 30
+}
+
+pub async fn receiver(
+    dst: &Path,
+    from_sender: impl AsyncRead + Unpin,
+    to_sender: impl AsyncWrite + Unpin,
+) -> Result<HashMap<PathBuf, Action>> {
+    let mut to_sender = FramedWrite::new(to_sender, LengthDelimitedCodec::new());
+    let mut from_sender = FramedRead::new(from_sender, LengthDelimitedCodec::new());
+
+    // Get the conda env fingerprint for the src and dst, and use that to create a
+    // comparator we can use to compare the mtimes between them.
+    let dst_env = CondaFingerprint::from_env(dst).await?;
+    let src_env: CondaFingerprint =
+        bincode::deserialize(&from_sender.next().await.context("fingerprint")??)?;
+    let comparator = CondaFingerprint::mtime_comparator(&src_env, &dst_env)?;
+    let ignores = GlobSetBuilder::new()
+        .add(Glob::new("**/*.pyc")?)
+        .add(Glob::new("**/__pycache__/")?)
+        .add(Glob::new("**/__pycache__/**/*")?)
+        .build()?;
+    let actions_builder = Arc::new(ActionsBuilder::new_with(Some(ignores), comparator));
+
+    // Process file lists from src/dst.
+    try_join!(
+        // Walk destination to grab file list.
+        async {
+            let dst = dst.to_path_buf();
+            let actions_builder = actions_builder.clone();
+            tokio::task::spawn_blocking(move || {
+                walk_dir(dst, move |path, ent| {
+                    actions_builder
+                        .process_dst(path.clone(), ent)
+                        .with_context(|| format!("{}", path.display()))
+                })
+            })
+            .await??;
+            anyhow::Ok(())
+        },
+        // Process file list sent from sender.
+        async {
+            while let FileList::Entry(path, metadata) =
+                bincode::deserialize(&from_sender.next().await.context("file list")??)?
+ { + actions_builder + .process_src(path.clone(), metadata) + .with_context(|| format!("{}", path.display()))?; + } + anyhow::Ok(()) + } + )?; + let actions = Arc::into_inner(actions_builder) + .expect("should be done") + .into_actions(); + + // Demultiplex FS actions. + let mut dirs = BTreeSet::new(); + let mut deletions = BTreeMap::new(); + let mut files = HashMap::new(); + for (path, action) in actions.iter() { + let path = path.clone(); + match action { + Action::Directory => { + dirs.insert(path); + } + Action::Receive(mtime, recv) => { + files.insert(path, (*mtime, recv)); + } + Action::Delete(is_dir) => { + deletions.insert(path, *is_dir); + } + } + } + + try_join!( + async { + // Process deletions first. + for (path, is_dir) in deletions.into_iter().rev() { + let fpath = dst.join(path); + if is_dir { + fs::remove_dir(&fpath).await + } else { + fs::remove_file(&fpath).await + } + .with_context(|| format!("deleting {}", fpath.display()))?; + } + + // Then create dirs. + for path in dirs.into_iter() { + let fpath = dst.join(path); + match fs::remove_file(&fpath).await { + Err(err) if err.kind() == ErrorKind::NotFound => Ok(()), + other => other, + } + .with_context(|| format!("clearing path {}", fpath.display()))?; + fs::create_dir(&fpath) + .await + .with_context(|| format!("creating dir {}", fpath.display()))?; + } + + let src_prefix = src_env.pack_meta.history.last_prefix()?; + let dst_prefix = dst_env.pack_meta.history.last_prefix()?; + let src_prefix_bytes = src_prefix.as_os_str().as_encoded_bytes(); + let dst_prefix_bytes = dst_prefix.as_os_str().as_encoded_bytes(); + let replacement_prefix = OnceCell::new(); + let finder = Finder::new(src_prefix_bytes); + + // Then pull file data and create files. + let mut from_sender = from_sender.into_inner(); + for _ in 0..files.len() { + // Read a file header. + let len = from_sender.read_u64_le().await?; + let mut buf = vec![0u8; len as usize]; + from_sender.read_exact(&mut buf).await?; + let FileContentsHeader { path, contents } = + bincode::deserialize(&buf).context("delta header")?; + let fpath = dst.join(&path); + match (contents, files.get(&path).context("missing file")?) { + // Read file contents and write to a tempfile. + (FileContents::File(len), (mtime, Receive::File(executable))) => { + let mut dst_tmp = + TempFile::new_in(fpath.parent().context("parent")?).await?; + let mut reader = (&mut from_sender).take(len); + + // Copy the file contents. + if src_prefix == dst_prefix { + tokio::io::copy(&mut reader, &mut dst_tmp).await?; + } else { + // We do different copies dependending on whether the file is binary or not. + let mut buf = vec![0; 4096]; + let len = reader.read(&mut buf[..]).await?; + buf.truncate(len); + if is_binary(&buf) { + dst_tmp.write_all(&buf).await?; + tokio::io::copy(&mut reader, &mut dst_tmp).await?; + + // For binary files Replace prefixes. + // SAFETY: use mmap for fast in-place prefix replacement + let mut mmap = unsafe { MmapMut::map_mut(&*dst_tmp)? }; + let mut offset = 0; + while let Some(pos) = finder.find(&mmap[offset..]) { + let trailing_byte = + &mmap[offset + pos + src_prefix_bytes.len()..][..1]; + if matches!(trailing_byte, b"" | b"/" | b"\0") { + let repl = replacement_prefix.get_or_try_init(|| { + null_pad(dst_prefix_bytes, src_prefix_bytes.len()) + })?; + mmap[offset + pos..offset + pos + src_prefix_bytes.len()] + .copy_from_slice(repl); + } + offset = pos + src_prefix_bytes.len(); + } + } else { + reader.read_to_end(&mut buf).await?; + // Replace prefixes. 
+                                replace_bytestring(&mut buf, src_prefix_bytes, dst_prefix_bytes);
+                                dst_tmp.write_all(&buf).await?;
+                            }
+                        }
+
+                        if *executable {
+                            make_executable(dst_tmp.file_path()).await?;
+                        }
+                        persist(dst_tmp, &fpath).await?;
+                        set_mtime(&fpath, *mtime).await?;
+                    }
+                    (FileContents::Symlink(mut target), (mtime, Receive::Symlink)) => {
+                        if let Ok(suffix) = target.strip_prefix(src_prefix) {
+                            target = dst_prefix.join(suffix);
+                        }
+                        fs::symlink(target, &fpath).await?;
+                        set_mtime(&fpath, *mtime).await?;
+                    }
+                    _ => bail!("unexpected file contents"),
+                }
+            }
+            anyhow::Ok(())
+        },
+        async {
+            to_sender
+                .send(bincode::serialize(&FileSectionHeader { num: files.len() })?.into())
+                .await
+                .context("sending sig section header")?;
+            for (path, (_, recv)) in files.iter() {
+                to_sender
+                    .send(
+                        bincode::serialize(&FileHeader {
+                            path: path.clone(),
+                            symlink: matches!(recv, Receive::Symlink),
+                        })?
+                        .into(),
+                    )
+                    .await
+                    .context("sending sig header")?;
+            }
+            to_sender.flush().await?;
+            anyhow::Ok(())
+        },
+    )?;
+
+    Ok(actions)
+}
+
+pub async fn sync(src: &Path, dst: &Path) -> Result<HashMap<PathBuf, Action>> {
+    // Receiver -> Sender
+    let (recv, send) = tokio::io::duplex(5 * 1024 * 1024);
+    let (from_receiver, to_receiver) = tokio::io::split(recv);
+    let (from_sender, to_sender) = tokio::io::split(send);
+    let (actions, ()) = try_join!(
+        receiver(dst, from_sender, to_sender),
+        sender(src, from_receiver, to_receiver),
+    )?;
+    Ok(actions)
+}
+
+#[cfg(test)]
+#[allow(clippy::disallowed_methods)]
+mod tests {
+    use std::collections::HashMap;
+    use std::os::unix::fs::PermissionsExt;
+    use std::path::Path;
+    use std::path::PathBuf;
+    use std::time::Duration;
+    use std::time::SystemTime;
+
+    use anyhow::Result;
+    use rattler_conda_types::package::FileMode;
+    use tempfile::TempDir;
+    use tokio::fs;
+
+    use super::Action;
+    use super::make_executable;
+    use super::set_mtime;
+    use super::sync;
+    use crate::pack_meta::History;
+    use crate::pack_meta::HistoryRecord;
+    use crate::pack_meta::Offset;
+    use crate::pack_meta::OffsetRecord;
+    use crate::pack_meta::Offsets;
+    use crate::sync::Receive;
+
+    /// Helper function to create a basic conda environment structure
+    async fn setup_conda_env<P: AsRef<Path>>(
+        dirpath: P,
+        mtime: SystemTime,
+        prefix: Option<&str>,
+    ) -> Result<P>

+    {
+        let env_path = dirpath.as_ref();
+
+        // Create the basic directory structure
+        fs::create_dir_all(&env_path).await?;
+        fs::create_dir(&env_path.join("conda-meta")).await?;
+        fs::create_dir(&env_path.join("pack-meta")).await?;
+
+        // Create a basic conda-meta file to establish fingerprint
+        add_file(
+            env_path,
+            "conda-meta/history",
+            "==> 2023-01-01 00:00:00 <==\npackage install actions\n",
+            mtime,
+            false,
+        )
+        .await?;
+
+        // Create a basic package record
+        add_file(
+            env_path,
+            "conda-meta/package-1.0-0.json",
+            r#"{
+    "name": "package",
+    "version": "1.0",
+    "build": "0",
+    "build_number": 0,
+    "paths_data": {
+        "paths": [
+            {
+                "path": "bin/test-file",
+                "path_type": "hardlink",
+                "size_in_bytes": 10,
+                "mode": "text"
+            }
+        ]
+    },
+    "repodata_record": {
+        "package_record": {
+            "timestamp": 1672531200
+        }
+    }
+}"#,
+            mtime,
+            false,
+        )
+        .await?;
+
+        // Create offsets.jsonl
+        let offsets = Offsets {
+            entries: vec![OffsetRecord {
+                path: PathBuf::from("bin/test-file"),
+                mode: FileMode::Text,
+                offsets: vec![Offset {
+                    start: 0,
+                    len: 10,
+                    contents: None,
+                }],
+            }],
+        };
+        add_file(
+            env_path,
+            "pack-meta/offsets.jsonl",
+            &offsets.to_str()?,
+            mtime,
+            false,
+        )
+        .await?;
+
+        // Create the actual file referenced in the metadata
+        fs::create_dir(env_path.join("bin")).await?;
+        add_file(env_path, "bin/test-file", "test data\n", mtime, false).await?;
+
+        // Create a file that was prefix-updated after the package was installed.
+        let window = (
+            mtime + Duration::from_secs(20),
+            mtime + Duration::from_secs(25),
+        );
+        fs::create_dir(env_path.join("lib")).await?;
+        add_file(
+            env_path,
+            "lib/libfoo.so",
+            "libfoo.so contents\n",
+            window.0 + Duration::from_secs(5),
+            false,
+        )
+        .await?;
+
+        // Use provided prefix or default to "base"
+        let prefix_path = PathBuf::from(prefix.unwrap_or("base"));
+
+        // Create history.jsonl
+        let history = History {
+            entries: vec![
+                HistoryRecord {
+                    timestamp: mtime.duration_since(SystemTime::UNIX_EPOCH)?.as_secs(),
+                    prefix: PathBuf::from("/conda/prefix"),
+                    finished: true,
+                },
+                HistoryRecord {
+                    timestamp: window.0.duration_since(SystemTime::UNIX_EPOCH)?.as_secs(),
+                    prefix: prefix_path.clone(),
+                    finished: false,
+                },
+                HistoryRecord {
+                    timestamp: window.1.duration_since(SystemTime::UNIX_EPOCH)?.as_secs(),
+                    prefix: prefix_path,
+                    finished: true,
+                },
+            ],
+        };
+        add_file(
+            env_path,
+            "pack-meta/history.jsonl",
+            &history.to_str()?,
+            mtime,
+            false,
+        )
+        .await?;
+
+        Ok(dirpath)
+    }
+
+    /// Helper function to modify a file in the conda environment
+    async fn modify_file(
+        env_path: &Path,
+        file_path: &str,
+        content: &str,
+        mtime: SystemTime,
+    ) -> Result<()> {
+        let full_path = env_path.join(file_path);
+        fs::write(&full_path, content).await?;
+
+        // Set the file time
+        set_mtime(&full_path, mtime).await?;
+
+        Ok(())
+    }
+
+    /// Helper function to add a new file to the conda environment
+    async fn add_file(
+        env_path: &Path,
+        file_path: &str,
+        content: &str,
+        mtime: SystemTime,
+        executable: bool,
+    ) -> Result<()> {
+        let full_path = env_path.join(file_path);
+        fs::write(&full_path, content).await?;
+
+        if executable {
+            make_executable(&full_path).await?;
+        }
+
+        // Set the file time
+        set_mtime(&full_path, mtime).await?;
+
+        Ok(())
+    }
+
+    /// Helper function to verify file content
+    async fn verify_file_content(path1: &Path, path2: &Path) -> Result<bool> {
+        let content1 = fs::read_to_string(path1).await?;
+        let content2 = fs::read_to_string(path2).await?;
+        Ok(content1 == content2)
+    }
+
+    /// Helper function to verify file permissions
+    async fn verify_file_permissions(path: &Path, expected_executable: bool) -> Result<bool> {
+        let metadata = fs::metadata(path).await?;
+        let mode = metadata.permissions().mode();
+        let is_executable = mode & 0o111 != 0;
+        Ok(is_executable == expected_executable)
+    }
+
+    /// Helper function to verify symlink target
+    async fn verify_symlink_target(path: &Path, expected_target: &Path) -> Result<bool> {
+        let target = fs::read_link(path).await?;
+        Ok(target == expected_target)
+    }
+
+    #[tokio::test]
+    async fn test_sync_modified_file() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Modify a file in the source environment
+        let modified_content = "modified test data\n";
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        modify_file(
+            src_env.path(),
+            "bin/test-file",
+            modified_content,
+            newer_time,
+        )
+        .await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions = HashMap::from([(
+            PathBuf::from("bin/test-file"),
+            Action::Receive(newer_time, Receive::File(false)),
+        )]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the file was updated in the destination
+        assert!(
+            verify_file_content(
+                &src_env.path().join("bin/test-file"),
+                &dst_env.path().join("bin/test-file")
+            )
+            .await?
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_new_file() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Add a new file to the source environment
+        let new_file_content = "new file content\n";
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        add_file(
+            src_env.path(),
+            "lib/new-file.txt",
+            new_file_content,
+            newer_time,
+            false,
+        )
+        .await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions = HashMap::from([
+            //(PathBuf::from("lib"), Action::Directory(newer_time)),
+            (
+                PathBuf::from("lib/new-file.txt"),
+                Action::Receive(newer_time, Receive::File(false)),
+            ),
+        ]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the new file was created in the destination
+        assert!(
+            verify_file_content(
+                &src_env.path().join("lib/new-file.txt"),
+                &dst_env.path().join("lib/new-file.txt")
+            )
+            .await?
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_directory_creation() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Create a new directory with a file in the source environment
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        fs::create_dir(src_env.path().join("new_dir")).await?;
+        add_file(
+            src_env.path(),
+            "new_dir/test.txt",
+            "test content",
+            newer_time,
+            false,
+        )
+        .await?;
+        set_mtime(&src_env.path().join("new_dir"), newer_time).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions = HashMap::from([
+            (PathBuf::from("new_dir"), Action::Directory),
+            (
+                PathBuf::from("new_dir/test.txt"),
+                Action::Receive(newer_time, Receive::File(false)),
+            ),
+        ]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the directory was created in the destination
+        assert!(dst_env.path().join("new_dir").exists());
+
+        // Verify the file was created in the destination
+        assert!(
+            verify_file_content(
+                &src_env.path().join("new_dir/test.txt"),
+                &dst_env.path().join("new_dir/test.txt")
+            )
+            .await?
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_symlink() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Create a symlink in the source environment
+        fs::symlink("bin/test-file", src_env.path().join("link-to-test")).await?;
+
+        // Set a newer time for the symlink to ensure it's synced
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        set_mtime(&src_env.path().join("link-to-test"), newer_time).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions = HashMap::from([(
+            PathBuf::from("link-to-test"),
+            Action::Receive(newer_time, Receive::Symlink),
+        )]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the symlink was created in the destination
+        assert!(dst_env.path().join("link-to-test").exists());
+
+        // Verify the symlink target
+        assert!(
+            verify_symlink_target(
+                &dst_env.path().join("link-to-test"),
+                &PathBuf::from("bin/test-file")
+            )
+            .await?
+        );
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_file_deletion() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Add an extra file to the destination that doesn't exist in source
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        add_file(
+            dst_env.path(),
+            "extra-file.txt",
+            "should be deleted",
+            newer_time,
+            false,
+        )
+        .await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions =
+            HashMap::from([(PathBuf::from("extra-file.txt"), Action::Delete(false))]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the extra file was deleted from the destination
+        assert!(!dst_env.path().join("extra-file.txt").exists());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_ignores_pyc_files() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Add a .pyc file to the source.
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        add_file(
+            src_env.path(),
+            "lib/test.pyc",
+            "compiled python",
+            newer_time,
+            false,
+        )
+        .await?;
+
+        // Add a file in __pycache__ directory to the destination
+        fs::create_dir(dst_env.path().join("lib/__pycache__")).await?;
+        add_file(
+            dst_env.path(),
+            "lib/__pycache__/cached.pyc",
+            "cached python",
+            newer_time,
+            false,
+        )
+        .await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // The new src-side .pyc file is ignored and never transferred; we only expect
+        // deletions of the pre-existing __pycache__ entries in the destination.
+        let expected_actions = HashMap::from([
+            (PathBuf::from("lib/__pycache__"), Action::Delete(true)),
+            (
+                PathBuf::from("lib/__pycache__/cached.pyc"),
+                Action::Delete(false),
+            ),
+        ]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the destination holds no .pyc/__pycache__ artifacts afterwards
+        assert!(!dst_env.path().join("lib/test.pyc").exists());
+        assert!(!dst_env.path().join("lib/__pycache__").exists());
+        assert!(!dst_env.path().join("lib/__pycache__/cached.pyc").exists());
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_executable_permissions() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup identical conda environments
+        let src_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, None).await?;
+
+        // Add an executable file to the source
+        let newer_time = base_time + Duration::from_secs(3600); // 1 hour later
+        add_file(
+            src_env.path(),
+            "bin/executable",
+            "#!/bin/sh\necho hello",
+            newer_time,
+            true,
+        )
+        .await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Create expected actions map
+        let expected_actions = HashMap::from([(
+            PathBuf::from("bin/executable"),
+            Action::Receive(newer_time, Receive::File(true)),
+        )]);
+
+        // Verify the entire actions map
+        assert_eq!(actions, expected_actions);
+
+        // Verify the file was created in the destination
+        assert!(dst_env.path().join("bin/executable").exists());
+
+        // Verify the file content was synced correctly
+        assert!(
+            verify_file_content(
+                &src_env.path().join("bin/executable"),
+                &dst_env.path().join("bin/executable")
+            )
+            .await?
+        );
+
+        // Verify the executable permissions were preserved
+        assert!(verify_file_permissions(&dst_env.path().join("bin/executable"), true).await?);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_text_file_prefix_replacement() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200); // 2023-01-01 00:00:00 UTC
+
+        // Setup conda environments with different prefixes
+        let src_prefix = "/opt/conda/src";
+        let dst_prefix = "/opt/conda/dst";
+        let src_env = setup_conda_env(TempDir::new()?, base_time, Some(src_prefix)).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, Some(dst_prefix)).await?;
+
+        // Add a text file with prefix references to the source
+        let newer_time = base_time + Duration::from_secs(3600);
+        let text_content = format!(
+            "#!/bin/bash\nexport PATH={}/bin:$PATH\necho 'Using prefix: {}'\n",
+            src_prefix, src_prefix
+        );
+        add_file(
+            src_env.path(),
+            "bin/script.sh",
+            &text_content,
+            newer_time,
+            true,
+        )
+        .await?;
+
+        // Add an empty file too.
+        add_file(src_env.path(), "bin/script2.sh", "", newer_time, true).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Verify the file was synced
+        let expected_actions = HashMap::from([
+            (
+                PathBuf::from("bin/script.sh"),
+                Action::Receive(newer_time, Receive::File(true)),
+            ),
+            (
+                PathBuf::from("bin/script2.sh"),
+                Action::Receive(newer_time, Receive::File(true)),
+            ),
+        ]);
+        assert_eq!(actions, expected_actions);
+
+        // Verify the prefix was replaced in the destination file
+        let dst_content = fs::read_to_string(dst_env.path().join("bin/script.sh")).await?;
+        let expected_content = format!(
+            "#!/bin/bash\nexport PATH={}/bin:$PATH\necho 'Using prefix: {}'\n",
+            dst_prefix, dst_prefix
+        );
+        assert_eq!(dst_content, expected_content);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_binary_file_prefix_replacement() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200);
+
+        // Setup conda environments with different prefixes
+        let src_prefix = "/opt/conda/src";
+        let dst_prefix = "/opt/conda/dst";
+        let src_env = setup_conda_env(TempDir::new()?, base_time, Some(src_prefix)).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, Some(dst_prefix)).await?;
+
+        // Create a binary file with embedded prefix and null bytes
+        let newer_time = base_time + Duration::from_secs(3600);
+        let mut binary_content = Vec::new();
+        binary_content.extend_from_slice(b"\x7fELF"); // ELF magic number
+        binary_content.extend_from_slice(&[0u8; 10]); // null bytes to make it binary
+        binary_content.extend_from_slice(src_prefix.as_bytes());
+        binary_content.extend_from_slice(&[0u8; 20]); // more null bytes
+        binary_content.extend_from_slice(b"end");
+
+        fs::write(src_env.path().join("lib/binary"), &binary_content).await?;
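+        // Give the binary a newer mtime so the comparator schedules it for transfer.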
+        set_mtime(&src_env.path().join("lib/binary"), newer_time).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Verify the file was synced
+        let expected_actions = HashMap::from([(
+            PathBuf::from("lib/binary"),
+            Action::Receive(newer_time, Receive::File(false)),
+        )]);
+        assert_eq!(actions, expected_actions);
+
+        // Verify the prefix was replaced in the binary file with null padding
+        let dst_content = fs::read(dst_env.path().join("lib/binary")).await?;
+
+        // The original file size should be preserved
+        assert_eq!(dst_content.len(), binary_content.len());
+
+        // Check that the prefix was replaced
+        let dst_content_str = String::from_utf8_lossy(&dst_content);
+        assert!(dst_content_str.contains(dst_prefix));
+        assert!(!dst_content_str.contains(src_prefix));
+
+        // Verify the ELF header and end marker are still present
+        assert!(dst_content.starts_with(b"\x7fELF"));
+        assert!(dst_content.ends_with(b"end"));
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_symlink_prefix_replacement() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200);
+
+        // Setup conda environments with different prefixes
+        let src_prefix = "/opt/conda/src";
+        let dst_prefix = "/opt/conda/dst";
+        let src_env = setup_conda_env(TempDir::new()?, base_time, Some(src_prefix)).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, Some(dst_prefix)).await?;
+
+        // Create a symlink that points to a path with the source prefix
+        let newer_time = base_time + Duration::from_secs(3600);
+        let symlink_target = format!("{}/lib/target-file", src_prefix);
+        fs::symlink(&symlink_target, src_env.path().join("bin/link-to-target")).await?;
+        set_mtime(&src_env.path().join("bin/link-to-target"), newer_time).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Verify the symlink was synced
+        let expected_actions = HashMap::from([(
+            PathBuf::from("bin/link-to-target"),
+            Action::Receive(newer_time, Receive::Symlink),
+        )]);
+        assert_eq!(actions, expected_actions);
+
+        // Verify the symlink target was updated with the destination prefix
+        let dst_target = fs::read_link(dst_env.path().join("bin/link-to-target")).await?;
+        let expected_target = PathBuf::from(format!("{}/lib/target-file", dst_prefix));
+        assert_eq!(dst_target, expected_target);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_symlink_no_prefix_replacement() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200);
+
+        // Setup conda environments with different prefixes
+        let src_prefix = "/opt/conda/src";
+        let dst_prefix = "/opt/conda/dst";
+        let src_env = setup_conda_env(TempDir::new()?, base_time, Some(src_prefix)).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, Some(dst_prefix)).await?;
+
+        // Create a symlink that points to a relative path (should not be modified)
+        let newer_time = base_time + Duration::from_secs(3600);
+        let symlink_target = "relative/path/target";
+        fs::symlink(&symlink_target, src_env.path().join("bin/relative-link")).await?;
+        set_mtime(&src_env.path().join("bin/relative-link"), newer_time).await?;
+
+        // Sync changes from source to destination
+        let actions = sync(src_env.path(), dst_env.path()).await?;
+
+        // Verify the symlink was synced
+        let expected_actions = HashMap::from([(
+            PathBuf::from("bin/relative-link"),
+            Action::Receive(newer_time, Receive::Symlink),
+        )]);
+        assert_eq!(actions, expected_actions);
+
+        // Verify the symlink target was NOT modified (since it doesn't start with src_prefix)
+        let dst_target = fs::read_link(dst_env.path().join("bin/relative-link")).await?;
+        let expected_target = PathBuf::from(symlink_target);
+        assert_eq!(dst_target, expected_target);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_sync_binary_file_prefix_replacement_fails_when_dst_longer() -> Result<()> {
+        // Set base time for consistent file timestamps
+        let base_time = SystemTime::UNIX_EPOCH + Duration::from_secs(1672531200);
+
+        // Setup conda environments where destination prefix is longer than source
+        let src_prefix = "/opt/src"; // Short prefix
+        let dst_prefix = "/opt/very/long/destination/prefix"; // Much longer prefix
+        let src_env = setup_conda_env(TempDir::new()?, base_time, Some(src_prefix)).await?;
+        let dst_env = setup_conda_env(TempDir::new()?, base_time, Some(dst_prefix)).await?;
+
+        // Create a binary file with embedded prefix and null bytes
+        let newer_time = base_time + Duration::from_secs(3600);
+        let mut binary_content = Vec::new();
+        binary_content.extend_from_slice(b"\x7fELF"); // ELF magic number
+        binary_content.extend_from_slice(&[0u8; 10]); // null bytes to make it binary
+        binary_content.extend_from_slice(src_prefix.as_bytes());
+        binary_content.extend_from_slice(&[0u8; 20]); // more null bytes
+        binary_content.extend_from_slice(b"end");
+
+        fs::write(src_env.path().join("lib/binary"), &binary_content).await?;
+        set_mtime(&src_env.path().join("lib/binary"), newer_time).await?;
+
+        // Sync changes from source to destination - this should fail
+        let result = sync(src_env.path(), dst_env.path()).await;
+
+        // Verify that the sync operation failed due to the destination prefix being longer
+        assert!(result.is_err());
+        let error_msg = result.unwrap_err().to_string();
+        assert!(error_msg.contains("Input is longer than target length"));
+
+        Ok(())
+    }
+}
diff --git a/monarch_extension/src/code_sync.rs b/monarch_extension/src/code_sync.rs
index 4b225b107..aba9e1530 100644
--- a/monarch_extension/src/code_sync.rs
+++ b/monarch_extension/src/code_sync.rs
@@ -12,13 +12,13 @@ use std::path::PathBuf;
 use anyhow::Result;
 use futures::TryFutureExt;
-use futures::future::try_join_all;
 use hyperactor_mesh::Mesh;
 use hyperactor_mesh::RootActorMesh;
 use hyperactor_mesh::shared_cell::SharedCell;
 use monarch_hyperactor::code_sync::WorkspaceLocation;
 use monarch_hyperactor::code_sync::manager::CodeSyncManager;
 use monarch_hyperactor::code_sync::manager::CodeSyncManagerParams;
+use monarch_hyperactor::code_sync::manager::CodeSyncMethod;
 use monarch_hyperactor::code_sync::manager::WorkspaceConfig;
 use monarch_hyperactor::code_sync::manager::WorkspaceShape;
 use monarch_hyperactor::code_sync::manager::code_sync_mesh;
@@ -118,6 +118,44 @@ impl RemoteWorkspace {
     }
 }
 
+#[pyclass(
+    frozen,
+    name = "CodeSyncMethod",
+    module = "monarch._rust_bindings.monarch_extension.code_sync"
+)]
+#[derive(Clone, Debug, Serialize, Deserialize)]
+enum PyCodeSyncMethod {
+    Rsync,
+    CondaSync,
+}
+
+impl From<PyCodeSyncMethod> for CodeSyncMethod {
+    fn from(method: PyCodeSyncMethod) -> CodeSyncMethod {
+        match method {
+            PyCodeSyncMethod::Rsync => CodeSyncMethod::Rsync,
+            PyCodeSyncMethod::CondaSync => CodeSyncMethod::CondaSync,
+        }
+    }
+}
+
+#[pymethods]
+impl PyCodeSyncMethod {
+    #[staticmethod]
+    fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult<Self> {
+        bincode::deserialize(bytes.as_bytes())
+            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
+    }
+
+    fn __reduce__<'py>(
+        slf: &Bound<'py, Self>,
+    ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> {
+        let bytes = bincode::serialize(&*slf.borrow())
+            .map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))?;
+        let py_bytes = PyBytes::new(slf.py(), &bytes);
+        Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,)))
+    }
+}
+
 #[pyclass(
     frozen,
     name = "WorkspaceConfig",
@@ -127,14 +165,19 @@ impl RemoteWorkspace {
 struct PyWorkspaceConfig {
     local: PathBuf,
     remote: RemoteWorkspace,
+    method: PyCodeSyncMethod,
 }
 
 #[pymethods]
 impl PyWorkspaceConfig {
     #[new]
-    #[pyo3(signature = (*, local, remote))]
-    fn new(local: PathBuf, remote: RemoteWorkspace) -> Self {
-        Self { local, remote }
+    #[pyo3(signature = (*, local, remote, method = PyCodeSyncMethod::Rsync))]
+    fn new(local: PathBuf, remote: RemoteWorkspace, method: PyCodeSyncMethod) -> Self {
+        Self {
+            local,
+            remote,
+            method,
+        }
     }
 }
 
@@ -152,6 +195,7 @@ impl CodeSyncMeshClient {
         actor_mesh: SharedCell<RootActorMesh<'static, CodeSyncManager>>,
         local: PathBuf,
         remote: RemoteWorkspace,
+        method: CodeSyncMethod,
         auto_reload: bool,
     ) -> Result<()> {
         let actor_mesh = actor_mesh.borrow()?;
@@ -164,7 +208,9 @@ impl CodeSyncMeshClient {
             location: remote.location.into(),
             shape,
         };
-        code_sync_mesh(&actor_mesh, local, remote, auto_reload).await?;
+        code_sync_mesh(&actor_mesh, local, remote, method, auto_reload)
+            .await
+            .map_err(|err| PyRuntimeError::new_err(format!("{:#?}", err)))?;
         Ok(())
     }
 }
@@ -183,12 +229,13 @@ impl CodeSyncMeshClient {
         })?
     }
 
-    #[pyo3(signature = (*, local, remote, auto_reload = false))]
+    #[pyo3(signature = (*, local, remote, method = PyCodeSyncMethod::Rsync, auto_reload = false))]
     fn sync_workspace<'py>(
         &self,
         py: Python<'py>,
         local: PathBuf,
         remote: RemoteWorkspace,
+        method: PyCodeSyncMethod,
         auto_reload: bool,
     ) -> PyResult<Bound<'py, PyAny>> {
         monarch_hyperactor::runtime::future_into_py(
@@ -197,6 +244,7 @@ impl CodeSyncMeshClient {
                 self.actor_mesh.clone(),
                 local,
                 remote,
+                method.into(),
                 auto_reload,
             )
             .err_into(),
@@ -213,14 +261,19 @@ impl CodeSyncMeshClient {
         let actor_mesh = self.actor_mesh.clone();
         monarch_hyperactor::runtime::future_into_py(
             py,
-            try_join_all(workspaces.into_iter().map(|workspace| {
-                CodeSyncMeshClient::sync_workspace_(
-                    actor_mesh.clone(),
-                    workspace.local,
-                    workspace.remote,
-                    auto_reload,
-                )
-            }))
+            async move {
+                for workspace in workspaces.into_iter() {
+                    CodeSyncMeshClient::sync_workspace_(
+                        actor_mesh.clone(),
+                        workspace.local,
+                        workspace.remote,
+                        workspace.method.into(),
+                        auto_reload,
+                    )
+                    .await?
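+                    // Workspaces are synced sequentially here, so a failure aborts
+                    // before later workspaces are touched.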
+                }
+                anyhow::Ok(())
+            }
             .err_into(),
         )
     }
@@ -228,6 +281,7 @@ pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
     module.add_class::()?;
+    module.add_class::<PyCodeSyncMethod>()?;
     module.add_class::()?;
     module.add_class::()?;
     module.add_class::()?;
diff --git a/monarch_hyperactor/Cargo.toml b/monarch_hyperactor/Cargo.toml
index 4138eaa94..8612369e3 100644
--- a/monarch_hyperactor/Cargo.toml
+++ b/monarch_hyperactor/Cargo.toml
@@ -25,7 +25,9 @@ hyperactor_mesh = { version = "0.0.0", path = "../hyperactor_mesh" }
 hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiprocess" }
 hyperactor_telemetry = { version = "0.0.0", path = "../hyperactor_telemetry" }
 inventory = "0.3.8"
+lazy_errors = "0.10.1"
 lazy_static = "1.5"
+monarch_conda = { version = "0.0.0", path = "../monarch_conda" }
 monarch_types = { version = "0.0.0", path = "../monarch_types" }
 ndslice = { version = "0.0.0", path = "../ndslice" }
 nix = { version = "0.30.1", features = ["dir", "event", "hostname", "inotify", "ioctl", "mman", "mount", "net", "poll", "ptrace", "reboot", "resource", "sched", "signal", "term", "time", "user", "zerocopy"] }
diff --git a/monarch_hyperactor/src/code_sync.rs b/monarch_hyperactor/src/code_sync.rs
index 7e318b67c..4047b1823 100644
--- a/monarch_hyperactor/src/code_sync.rs
+++ b/monarch_hyperactor/src/code_sync.rs
@@ -7,6 +7,7 @@
  */
 
 pub mod auto_reload;
+pub mod conda_sync;
 pub mod manager;
 pub mod rsync;
 mod workspace;
diff --git a/monarch_hyperactor/src/code_sync/conda_sync.rs b/monarch_hyperactor/src/code_sync/conda_sync.rs
new file mode 100644
index 000000000..b77b6b451
--- /dev/null
+++ b/monarch_hyperactor/src/code_sync/conda_sync.rs
@@ -0,0 +1,178 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+use std::collections::HashMap;
+use std::path::PathBuf;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use futures::FutureExt;
+use futures::StreamExt;
+use futures::TryStreamExt;
+use hyperactor::Actor;
+use hyperactor::Bind;
+use hyperactor::Handler;
+use hyperactor::Named;
+use hyperactor::PortRef;
+use hyperactor::Unbind;
+use hyperactor_mesh::actor_mesh::ActorMesh;
+use hyperactor_mesh::connect::Connect;
+use hyperactor_mesh::connect::accept;
+use hyperactor_mesh::sel;
+use lazy_errors::ErrorStash;
+use lazy_errors::OrStash;
+use lazy_errors::StashedResult;
+use lazy_errors::TryCollectOrStash;
+use monarch_conda::sync::Action;
+use monarch_conda::sync::receiver;
+use monarch_conda::sync::sender;
+use ndslice::Selection;
+use serde::Deserialize;
+use serde::Serialize;
+use tokio::io::AsyncReadExt;
+use tokio::io::AsyncWriteExt;
+
+use crate::code_sync::WorkspaceLocation;
+
+/// Represents the result of a conda sync operation, with details about what was transferred.
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Named)]
+pub struct CondaSyncResult {
+    /// All changes that occurred during the sync operation.
+    pub changes: HashMap<PathBuf, Action>,
+}
+
+#[derive(Debug, Clone, Named, Serialize, Deserialize, Bind, Unbind)]
+pub struct CondaSyncMessage {
+    /// The connect message to create a duplex bytestream with the client.
+    pub connect: PortRef<Connect>,
+    /// A port to send back the result or any errors.
+    pub result: PortRef<Result<CondaSyncResult, String>>,
+    /// The location of the workspace to sync.
+    pub workspace: WorkspaceLocation,
+}
+
+#[derive(Debug, Named, Serialize, Deserialize)]
+pub struct CondaSyncParams {}
+
+#[derive(Debug)]
+#[hyperactor::export(spawn = true, handlers = [CondaSyncMessage { cast = true }])]
+pub struct CondaSyncActor {}
+
+#[async_trait]
+impl Actor for CondaSyncActor {
+    type Params = CondaSyncParams;
+
+    async fn new(CondaSyncParams {}: Self::Params) -> Result<Self> {
+        Ok(Self {})
+    }
+}
+
+#[async_trait]
+impl Handler<CondaSyncMessage> for CondaSyncActor {
+    async fn handle(
+        &mut self,
+        cx: &hyperactor::Context<Self>,
+        CondaSyncMessage {
+            workspace,
+            connect,
+            result,
+        }: CondaSyncMessage,
+    ) -> Result<(), anyhow::Error> {
+        let res = async {
+            let workspace = workspace.resolve()?;
+            let (connect_msg, completer) = Connect::allocate(cx.self_id().clone(), cx);
+            connect.send(cx, connect_msg)?;
+            let (mut read, mut write) = completer.complete().await?.into_split();
+            let changes_result = receiver(&workspace, &mut read, &mut write).await;
+
+            // Shut down our end, then read from the other end until exhaustion to avoid
+            // undeliverable-message spam.
+            write.shutdown().await?;
+            let mut buf = vec![];
+            read.read_to_end(&mut buf).await?;
+
+            anyhow::Ok(CondaSyncResult {
+                changes: changes_result?,
+            })
+        }
+        .await;
+        result.send(cx, res.map_err(|e| format!("{:#?}", e)))?;
+        Ok(())
+    }
+}
+
+pub async fn conda_sync_mesh<M>(
+    actor_mesh: &M,
+    local_workspace: PathBuf,
+    remote_workspace: WorkspaceLocation,
+) -> Result<Vec<CondaSyncResult>>
+where
+    M: ActorMesh<Actor = CondaSyncActor>,
+{
+    let mailbox = actor_mesh.proc_mesh().client();
+    let (conns_tx, conns_rx) = mailbox.open_port::<Connect>();
+
+    let (res1, res2) = futures::future::join(
+        conns_rx
+            .take(actor_mesh.shape().slice().len())
+            .err_into::<anyhow::Error>()
+            .try_for_each_concurrent(None, |connect| async {
+                let (mut read, mut write) = accept(mailbox, mailbox.actor_id().clone(), connect)
+                    .await?
+                    .into_split();
+                let res = sender(&local_workspace, &mut read, &mut write).await;
+
+                // Shut down our end, then read from the other end until exhaustion to avoid
+                // undeliverable-message spam.
+                write.shutdown().await?;
+                let mut buf = vec![];
+                read.read_to_end(&mut buf).await?;
+
+                res
+            })
+            .boxed(),
+        async move {
+            let (result_tx, result_rx) = mailbox.open_port::<Result<CondaSyncResult, String>>();
+            actor_mesh.cast(
+                sel!(*),
+                CondaSyncMessage {
+                    connect: conns_tx.bind(),
+                    result: result_tx.bind(),
+                    workspace: remote_workspace,
+                },
+            )?;
+
+            // Wait for all actors to report a result.
+            let results = result_rx
+                .take(actor_mesh.shape().slice().len())
+                .try_collect::<Vec<_>>()
+                .await?;
+
+            // Combine all errors into one.
+            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
+            match results
+                .into_iter()
+                .map(|res| res.map_err(anyhow::Error::msg))
+                .try_collect_or_stash::<Vec<_>>(&mut errs)
+            {
+                StashedResult::Ok(results) => anyhow::Ok(results),
+                StashedResult::Err(_) => Err(errs.into_result().unwrap_err().into()),
+            }
+        },
+    )
+    .await;
+
+    // Combine code sync handler and cast errors into one.
+    let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "code sync failed");
+    res1.or_stash(&mut errs);
+    if let StashedResult::Ok(results) = res2.or_stash(&mut errs) {
+        errs.into_result()?;
+        return Ok(results);
+    }
+    Err(errs.into_result().unwrap_err().into())
+}
diff --git a/monarch_hyperactor/src/code_sync/manager.rs b/monarch_hyperactor/src/code_sync/manager.rs
index 07898c2f1..904be9ee5 100644
--- a/monarch_hyperactor/src/code_sync/manager.rs
+++ b/monarch_hyperactor/src/code_sync/manager.rs
@@ -39,11 +39,16 @@ use hyperactor_mesh::reference::ActorMeshId;
 use hyperactor_mesh::reference::ActorMeshRef;
 use hyperactor_mesh::reference::ProcMeshId;
 use hyperactor_mesh::sel;
+use lazy_errors::ErrorStash;
+use lazy_errors::TryCollectOrStash;
+use monarch_conda::sync::sender;
 use ndslice::Selection;
 use ndslice::Shape;
 use ndslice::ShapeError;
 use serde::Deserialize;
 use serde::Serialize;
+use tokio::io::AsyncReadExt;
+use tokio::io::AsyncWriteExt;
 use tokio::net::TcpListener;
 use tokio::net::TcpStream;
 
@@ -51,6 +56,10 @@ use crate::code_sync::WorkspaceLocation;
 use crate::code_sync::auto_reload::AutoReloadActor;
 use crate::code_sync::auto_reload::AutoReloadMessage;
 use crate::code_sync::auto_reload::AutoReloadParams;
+use crate::code_sync::conda_sync::CondaSyncActor;
+use crate::code_sync::conda_sync::CondaSyncMessage;
+use crate::code_sync::conda_sync::CondaSyncParams;
+use crate::code_sync::conda_sync::CondaSyncResult;
 use crate::code_sync::rsync::RsyncActor;
 use crate::code_sync::rsync::RsyncDaemon;
 use crate::code_sync::rsync::RsyncMessage;
@@ -60,6 +69,7 @@ use crate::code_sync::rsync::RsyncResult;
 #[derive(Clone, Serialize, Deserialize, Debug)]
 pub enum Method {
     Rsync { connect: PortRef<Connect> },
+    CondaSync { connect: PortRef<Connect> },
 }
 
 /// Describe the shape of the workspace.
@@ -167,6 +177,7 @@ pub struct CodeSyncManagerParams {}
 pub struct CodeSyncManager {
     rsync: OnceCell<ActorHandle<RsyncActor>>,
     auto_reload: OnceCell<ActorHandle<AutoReloadActor>>,
+    conda_sync: OnceCell<ActorHandle<CondaSyncActor>>,
 }
 
 #[async_trait]
@@ -177,6 +188,7 @@ impl Actor for CodeSyncManager {
         Ok(Self {
             rsync: OnceCell::new(),
             auto_reload: OnceCell::new(),
+            conda_sync: OnceCell::new(),
         })
     }
 }
@@ -199,6 +211,15 @@ impl CodeSyncManager {
             .get_or_try_init(AutoReloadActor::spawn(cx, AutoReloadParams {}))
             .await
     }
+
+    async fn get_conda_sync_actor<'a>(
+        &'a mut self,
+        cx: &Context<'a, Self>,
+    ) -> Result<&'a ActorHandle<CondaSyncActor>> {
+        self.conda_sync
+            .get_or_try_init(CondaSyncActor::spawn(cx, CondaSyncParams {}))
+            .await
+    }
 }
 
 #[async_trait]
@@ -226,6 +247,20 @@ impl CodeSyncMessageHandler for CodeSyncManager {
                 // Observe any errors.
                 let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
             }
+            Method::CondaSync { connect } => {
+                // Forward the connection port to the CondaSyncActor, which will do the
+                // actual connection and run the sync.
+                let (tx, mut rx) = cx.open_port::<Result<CondaSyncResult, String>>();
+                self.get_conda_sync_actor(cx)
+                    .await?
+                    .send(CondaSyncMessage {
+                        connect,
+                        result: tx.bind(),
+                        workspace,
+                    })?;
+                // Observe any errors.
+                let _ = rx.recv().await?.map_err(anyhow::Error::msg)?;
+            }
         }
 
         // Trigger hot reload on all ranks that use/share this workspace.
@@ -296,10 +331,17 @@ impl CodeSyncMessageHandler for CodeSyncManager {
     }
 }
 
+#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
+pub enum CodeSyncMethod {
+    Rsync,
+    CondaSync,
+}
+
 pub async fn code_sync_mesh(
     actor_mesh: &RootActorMesh<'_, CodeSyncManager>,
     local_workspace: PathBuf,
     remote_workspace: WorkspaceConfig,
+    method: CodeSyncMethod,
     auto_reload: bool,
 ) -> Result<()> {
     let mailbox = actor_mesh.proc_mesh().client();
@@ -307,36 +349,82 @@ pub async fn code_sync_mesh(
     // Create a slice of the actor mesh that only includes workspace "owners" (e.g. on multi-GPU hosts,
     // only one of the ranks on that host will participate in the code sync).
     let actor_mesh = SlicedActorMesh::new(actor_mesh, remote_workspace.shape.owners()?);
+    let shape = actor_mesh.shape().clone();
+
+    let (method, method_fut) = match method {
+        CodeSyncMethod::Rsync => {
+            // Spawn a rsync daemon to accept incoming connections from actors.
+            let daemon =
+                RsyncDaemon::spawn(TcpListener::bind(("::1", 0)).await?, &local_workspace).await?;
+            let daemon_addr = daemon.addr().clone();
+            let (rsync_conns_tx, rsync_conns_rx) = mailbox.open_port::<Connect>();
+            (
+                Method::Rsync {
+                    connect: rsync_conns_tx.bind(),
+                },
+                // This async task will process rsync connection attempts concurrently, forwarding
+                // them to the rsync daemon above.
+                async move {
+                    let res = rsync_conns_rx
+                        .take(shape.slice().len())
+                        .err_into::<anyhow::Error>()
+                        .try_for_each_concurrent(None, |connect| async move {
+                            let (mut local, mut stream) = try_join!(
+                                TcpStream::connect(daemon_addr.clone()).err_into(),
+                                accept(mailbox, mailbox.actor_id().clone(), connect),
+                            )?;
+                            tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
+                            Ok(())
+                        })
+                        .await;
+                    daemon.shutdown().await?;
+                    res?;
+                    anyhow::Ok(())
+                }
+                .boxed(),
+            )
+        }
+        CodeSyncMethod::CondaSync => {
+            let (conns_tx, conns_rx) = mailbox.open_port::<Connect>();
+            (
+                Method::CondaSync {
+                    connect: conns_tx.bind(),
+                },
+                async move {
+                    conns_rx
+                        .take(shape.slice().len())
+                        .err_into::<anyhow::Error>()
+                        .try_for_each_concurrent(None, |connect| async {
+                            let (mut read, mut write) =
+                                accept(mailbox, mailbox.actor_id().clone(), connect)
+                                    .await?
+                                    .into_split();
+                            let res = sender(&local_workspace, &mut read, &mut write).await;
+
+                            // Shut down our end, then read from the other end until exhaustion
+                            // to avoid undeliverable-message spam.
+                            write.shutdown().await?;
+                            let mut buf = vec![];
+                            read.read_to_end(&mut buf).await?;
+
+                            res
+                        })
+                        .await
+                }
+                .boxed(),
+            )
+        }
+    };
-    // Spawn a rsync daemon to accept incoming connections from actors.
-    let daemon = RsyncDaemon::spawn(TcpListener::bind(("::1", 0)).await?, &local_workspace).await?;
-    let daemon_addr = daemon.addr();
-
-    let (rsync_conns_tx, rsync_conns_rx) = mailbox.open_port::<Connect>();
-    let ((), ()) = try_join!(
-        // This async task will process rsync connection attempts concurrently, forwarding them to
-        // the rsync daemon above.
-        rsync_conns_rx
-            .take(actor_mesh.shape().slice().len())
-            .err_into::<anyhow::Error>()
-            .try_for_each_concurrent(None, |connect| async move {
-                let (mut local, mut stream) = try_join!(
-                    TcpStream::connect(daemon_addr.clone()).err_into(),
-                    accept(mailbox, mailbox.actor_id().clone(), connect),
-                )?;
-                tokio::io::copy_bidirectional(&mut local, &mut stream).await?;
-                Ok(())
-            })
-            .boxed(),
+    let (res1, res2) = futures::future::join(
+        method_fut,
         // This async task will cast the code sync message to workspace owners, and process any errors.
         async move {
             let (result_tx, result_rx) = mailbox.open_port::<Result<(), String>>();
             actor_mesh.cast(
                 sel!(*),
                 CodeSyncMessage::Sync {
-                    method: Method::Rsync {
-                        connect: rsync_conns_tx.bind(),
-                    },
+                    method,
                     workspace: remote_workspace.location.clone(),
                     reload: if auto_reload {
                         Some(remote_workspace.shape)
@@ -346,18 +434,30 @@ pub async fn code_sync_mesh(
                     result: result_tx.bind(),
                 },
             )?;
-            let _: Vec<()> = result_rx
+
+            // Wait for all actors to report a result.
+            let results = result_rx
                 .take(actor_mesh.shape().slice().len())
-                .map(|res| res?.map_err(anyhow::Error::msg))
-                .try_collect()
+                .try_collect::<Vec<_>>()
                 .await?;
-            Ok(())
-        },
-    )?;
-    daemon.shutdown().await?;
-
-    Ok(())
+            // Combine all errors into one.
+            let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "remote failures");
+            results
+                .into_iter()
+                .map(|res| res.map_err(anyhow::Error::msg))
+                .try_collect_or_stash::<()>(&mut errs);
+            Ok(errs.into_result()?)
+        },
+    )
+    .await;
+
+    // Combine code sync handler and cast errors into one.
+    let mut errs = ErrorStash::<_, _, anyhow::Error>::new(|| "code sync failed");
+    [res1, res2]
+        .into_iter()
+        .try_collect_or_stash::<()>(&mut errs);
+    Ok(errs.into_result()?)
 }
 
 #[cfg(test)]
@@ -486,6 +586,7 @@ mod tests {
             &actor_mesh,
             source_workspace.path().to_path_buf(),
             remote_workspace_config.clone(),
+            CodeSyncMethod::Rsync,
             false, // no auto-reload
         )
         .await?;
diff --git a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi
index 6f0c1922f..66413522c 100644
--- a/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi
+++ b/python/monarch/_rust_bindings/monarch_extension/code_sync.pyi
@@ -39,6 +39,14 @@ class WorkspaceShape:
     @staticmethod
     def exclusive() -> "WorkspaceShape": ...
 
+class CodeSyncMethod:
+    """
+    Python binding for the Rust CodeSyncMethod enum.
+    """
+
+    Rsync: CodeSyncMethod
+    CondaSync: CodeSyncMethod
+
 @final
 class RemoteWorkspace:
     """
@@ -46,6 +54,19 @@
     """
     def __init__(self, location: WorkspaceLocation, shape: WorkspaceShape) -> None: ...
 
+@final
+class WorkspaceConfig:
+    """
+    Python binding for the Rust WorkspaceConfig struct.
+    """
+    def __init__(
+        self,
+        *,
+        local: Path,
+        remote: RemoteWorkspace,
+        method: CodeSyncMethod = ...,
+    ) -> None: ...
+
 @final
 class CodeSyncMeshClient:
     """
@@ -65,6 +86,6 @@ class CodeSyncMeshClient:
     async def sync_workspaces(
         self,
         *,
-        workspaces: list[tuple[str, RemoteWorkspace]],
+        workspaces: list[WorkspaceConfig],
         auto_reload: bool = False,
     ) -> None: ...
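
Aside (not part of the patch): the stubs above pin down the new user-facing surface, so a small usage sketch may help. The snippet below is illustrative only; it assumes a `client` that was already spawned elsewhere (e.g. via `CodeSyncMeshClient.spawn_blocking`, as `proc_mesh.py` does below), and borrows the `WORKSPACE_DIR` env var and the `"gpus"` shape label from that same call site.

```python
# Illustrative sketch of the new typed sync_workspaces() surface.
from pathlib import Path

from monarch._src.actor.code_sync import (
    CodeSyncMeshClient,
    CodeSyncMethod,
    RemoteWorkspace,
    WorkspaceConfig,
    WorkspaceLocation,
    WorkspaceShape,
)


async def push_source_tree(client: CodeSyncMeshClient, local_dir: str) -> None:
    workspace = WorkspaceConfig(
        local=Path(local_dir),
        remote=RemoteWorkspace(
            # The remote path is resolved from $WORKSPACE_DIR on each host ...
            location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
            # ... and ranks along the "gpus" dim share one workspace per host.
            shape=WorkspaceShape.shared("gpus"),
        ),
        method=CodeSyncMethod.Rsync,  # or CodeSyncMethod.CondaSync for a conda env
    )
    await client.sync_workspaces(workspaces=[workspace], auto_reload=True)
```

Note that `sync_workspaces` now processes its workspaces sequentially (the `for` loop replacing `try_join_all` in `monarch_extension/src/code_sync.rs` above), so mixing an rsync'd source tree with a conda-synced env in a single call has a well-defined order.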
diff --git a/python/monarch/_src/actor/code_sync/__init__.py b/python/monarch/_src/actor/code_sync/__init__.py
index 667696a6f..610253662 100644
--- a/python/monarch/_src/actor/code_sync/__init__.py
+++ b/python/monarch/_src/actor/code_sync/__init__.py
@@ -6,7 +6,9 @@
 from monarch._rust_bindings.monarch_extension.code_sync import (  # noqa: F401
     CodeSyncMeshClient,
+    CodeSyncMethod,
     RemoteWorkspace,
+    WorkspaceConfig,
     WorkspaceLocation,
     WorkspaceShape,
 )
diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py
index 3608f915d..c3ebe514e 100644
--- a/python/monarch/_src/actor/proc_mesh.py
+++ b/python/monarch/_src/actor/proc_mesh.py
@@ -8,11 +8,11 @@
 import asyncio
 import logging
-import os
 import sys
 import threading
 import warnings
 from contextlib import AbstractContextManager
+from pathlib import Path
 
 from typing import (
     Any,
@@ -58,7 +58,9 @@
 )
 from monarch._src.actor.code_sync import (
     CodeSyncMeshClient,
+    CodeSyncMethod,
     RemoteWorkspace,
+    WorkspaceConfig,
     WorkspaceLocation,
     WorkspaceShape,
 )
@@ -73,6 +75,8 @@
 from monarch._src.actor.endpoint import endpoint
 from monarch._src.actor.future import DeprecatedNotAFuture, Future
 from monarch._src.actor.shape import MeshTrait
+from monarch.tools.config import Workspace
+from monarch.tools.utils import conda as conda_utils
 
 HAS_TENSOR_ENGINE = False
 try:
@@ -359,25 +363,53 @@ def rank_tensor(self, dim: str | Sequence[str]) -> "Tensor":
     def rank_tensors(self) -> Dict[str, "Tensor"]:
         return self._device_mesh.ranks
 
-    async def sync_workspace(self, auto_reload: bool = False) -> None:
+    async def sync_workspace(
+        self,
+        workspace: Workspace = None,
+        conda: bool = False,
+        auto_reload: bool = False,
+    ) -> None:
         if self._code_sync_client is None:
             self._code_sync_client = CodeSyncMeshClient.spawn_blocking(
                 proc_mesh=await self._proc_mesh_for_asyncio_fixme,
             )
+
         # TODO(agallagher): We need some way to configure and pass this
         # in -- right now we're assuming the `gpu` dimension, which isn't
         # correct.
         # The workspace shape (i.e. only perform one rsync per host).
         assert set(self._shape.labels).issubset({"gpus", "hosts"})
+
+        workspaces = []
+        if workspace is not None:
+            workspaces.append(
+                WorkspaceConfig(
+                    local=Path(workspace),
+                    remote=RemoteWorkspace(
+                        location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
+                        shape=WorkspaceShape.shared("gpus"),
+                    ),
+                    method=CodeSyncMethod.Rsync,
+                ),
+            )
+
+        # If `conda` is set, also sync the currently activated conda env.
+        conda_prefix = conda_utils.active_env_dir()
+        if conda and conda_prefix is not None:
+            workspaces.append(
+                WorkspaceConfig(
+                    local=Path(conda_prefix),
+                    remote=RemoteWorkspace(
+                        location=WorkspaceLocation.FromEnvVar("CONDA_PREFIX"),
+                        shape=WorkspaceShape.shared("gpus"),
+                    ),
+                    method=CodeSyncMethod.CondaSync,
+                ),
+            )
+
         assert self._code_sync_client is not None
-        await self._code_sync_client.sync_workspace(
-            # TODO(agallagher): Is there a better way to infer/set the local
-            # workspace dir, rather than use PWD?
-            local=os.getcwd(),
-            remote=RemoteWorkspace(
-                location=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"),
-                shape=WorkspaceShape.shared("gpus"),
-            ),
+        await self._code_sync_client.sync_workspaces(
+            workspaces=workspaces,
             auto_reload=auto_reload,
         )
diff --git a/python/monarch/tools/config/__init__.py b/python/monarch/tools/config/__init__.py
index 8f03bca30..edce3f3da 100644
--- a/python/monarch/tools/config/__init__.py
+++ b/python/monarch/tools/config/__init__.py
@@ -6,7 +6,7 @@
 # pyre-strict
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List
 
 from torchx.specs import Role
 
@@ -24,6 +24,14 @@ class UnnamedAppDef:
     metadata: Dict[str, str] = field(default_factory=dict)
 
 
+# TODO: provide a proper Workspace class to support
+# - multiple workspaces
+# - empty workspaces
+# - no workspace
+# - experimental directories
+Workspace = str | None
+
+
 @dataclass
 class Config:
     """
@@ -32,6 +40,6 @@ class Config:
 
     scheduler: str = NOT_SET
     scheduler_args: dict[str, Any] = field(default_factory=dict)
-    workspace: Optional[str] = None
+    workspace: Workspace = None
     dryrun: bool = False
     appdef: UnnamedAppDef = field(default_factory=UnnamedAppDef)
diff --git a/python/monarch/tools/config/defaults.py b/python/monarch/tools/config/defaults.py
index af19385cb..39effb7fa 100644
--- a/python/monarch/tools/config/defaults.py
+++ b/python/monarch/tools/config/defaults.py
@@ -8,10 +8,10 @@
 """Defines defaults for ``monarch.tools``"""
 
-from typing import Callable, Optional
+from typing import Callable
 
 from monarch.tools.components import hyperactor
-from monarch.tools.config import Config, UnnamedAppDef
+from monarch.tools.config import Config, UnnamedAppDef, Workspace
 from torchx import specs
 
 from torchx.schedulers import (
@@ -40,7 +40,7 @@ def scheduler_factories() -> dict[str, SchedulerFactory]:
     }
 
 
-def config(scheduler: str, workspace: Optional[str] = None) -> Config:
+def config(scheduler: str, workspace: Workspace = None) -> Config:
     """The default :py:class:`~monarch.tools.config.Config` to use
     when submitting to the provided ``scheduler``."""
     return Config(scheduler=scheduler, workspace=workspace)
diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py
index 20dae03bb..96c733b3b 100644
--- a/python/tests/test_python_actors.py
+++ b/python/tests/test_python_actors.py
@@ -35,6 +35,7 @@
     local_proc_mesh,
     proc_mesh,
 )
+from monarch.tools.config import defaults
 
 from typing_extensions import assert_type
 
@@ -748,6 +749,49 @@ async def test_same_actor_twice() -> None:
     ), f"Expected error message about duplicate actor name, got: {error_msg}"
 
 
+class LsActor(Actor):
+    def __init__(self, workspace: str):
+        self.workspace = workspace
+
+    @endpoint
+    async def ls(self) -> list[str]:
+        return os.listdir(self.workspace)
+
+
+async def test_sync_workspace() -> None:
+    pm = await proc_mesh(gpus=1)
+
+    # create two workspaces: one for local and one for remote
+    with tempfile.TemporaryDirectory() as workspace_src, tempfile.TemporaryDirectory() as workspace_dst, unittest.mock.patch.dict(
+        os.environ, {"WORKSPACE_DIR": workspace_dst}
+    ):
+        os.environ["WORKSPACE_DIR"] = workspace_dst
+        config = defaults.config("slurm", workspace_src)
+        await pm.sync_workspace(
+            workspace=config.workspace, conda=False, auto_reload=True
+        )
+
+        # no file in the remote workspace initially
+        am = await pm.spawn("ls", LsActor, workspace_dst)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 0
+
+        # write a file to the local workspace
+        file_path = os.path.join(workspace_src, "new_file")
+        with open(file_path, "w") as f:
+            f.write("hello world")
+            f.flush()
+
+        # force a sync and it should populate on the dst workspace
+        await pm.sync_workspace(config.workspace, conda=False, auto_reload=True)
+        for item in list(am.ls.call().get()):
+            assert len(item[1]) == 1
+            assert item[1][0] == "new_file"
+            file_path = os.path.join(workspace_dst, item[1][0])
+            with open(file_path, "r") as f:
+                assert f.readline() == "hello world"
+
+
 class TestActorMeshStop(unittest.IsolatedAsyncioTestCase):
     async def test_actor_mesh_stop(self) -> None:
         pm = proc_mesh(gpus=2)
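
To round out the test above, here is a hedged end-to-end sketch of the new `ProcMesh.sync_workspace` signature with the conda path enabled. The `proc_mesh` import path and the workspace path are illustrative assumptions, not part of the diff; `conda=True` only adds a second (conda-synced) workspace when an active conda env is detected via `conda_utils.active_env_dir()`, and the remote side must be able to resolve `$CONDA_PREFIX`.

```python
# Illustrative only: exercising sync_workspace() with the new conda path.
import asyncio

from monarch._src.actor.proc_mesh import proc_mesh  # assumed import path
from monarch.tools.config import defaults


async def main() -> None:
    pm = await proc_mesh(gpus=1)
    # Hypothetical local source tree; on real clusters this comes from Config.
    config = defaults.config("slurm", "/path/to/local/workspace")
    # Rsync the source workspace and (if a conda env is active) conda-sync it too.
    await pm.sync_workspace(workspace=config.workspace, conda=True, auto_reload=True)


asyncio.run(main())
```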