Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions tfhe/src/high_level_api/global_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ where
})
}

#[cfg(feature = "gpu")]
pub use gpu::clear_gpu_thread_locals;
#[cfg(feature = "gpu")]
pub(in crate::high_level_api) use gpu::with_thread_local_cuda_streams_for_gpu_indexes;

Expand Down Expand Up @@ -222,6 +224,21 @@ mod gpu {
.collect(),
}
}

/// Resets this pool's thread-local CUDA state: drops the custom streams
/// and re-arms every per-GPU lazy stream so it is re-created on next use.
fn clear(&mut self) {
    self.custom.take();
    // Re-arm each per-GPU LazyCell instead of emptying the vec: this keeps
    // the StreamPool itself alive, and the streams will simply be re-created
    // lazily the next time they are accessed.
    for (gpu_idx, slot) in self.single.iter_mut().enumerate() {
        let make_streams: Box<dyn Fn() -> CudaStreams> =
            Box::new(move || CudaStreams::new_single_gpu(GpuIndex::new(gpu_idx as u32)));
        *slot = LazyCell::new(make_streams);
    }
}
}

// Per-thread pool of CUDA streams used by the high-level API helpers.
// Cleared (and lazily re-created on next use) via `clear_gpu_thread_locals`.
thread_local! {
    static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
}

pub(in crate::high_level_api) fn with_thread_local_cuda_streams_for_gpu_indexes<
Expand All @@ -231,10 +248,6 @@ mod gpu {
gpu_indexes: &[GpuIndex],
func: F,
) -> R {
thread_local! {
static POOL: RefCell<CudaStreamPool> = RefCell::new(CudaStreamPool::new());
}

if gpu_indexes.len() == 1 {
POOL.with_borrow(|pool| func(&pool.single[gpu_indexes[0].get() as usize]))
} else {
Expand Down Expand Up @@ -296,6 +309,13 @@ mod gpu {
}
}
}

/// Clears all the thread-locals that store CUDA-related items:
/// the server key set on the current thread as well as the per-thread
/// CUDA stream pool.
///
/// The cleared thread-locals are lazily re-created on their next use,
/// so the calling thread remains usable afterwards.
pub fn clear_gpu_thread_locals() {
    unset_server_key();
    POOL.with_borrow_mut(|pool| pool.clear());
}
}

#[cfg(feature = "hpu")]
Expand Down
6 changes: 3 additions & 3 deletions tfhe/src/high_level_api/integers/oprf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -827,8 +827,8 @@ mod test {
use crate::prelude::check_valid_cuda_malloc_assert_oom;
use crate::shortint::parameters::PARAM_GPU_MULTI_BIT_GROUP_4_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
use crate::{
unset_server_key, ClientKey, CompressedServerKey, FheInt128, FheUint32, FheUint64,
GpuIndex,
clear_gpu_thread_locals, ClientKey, CompressedServerKey, FheInt128, FheUint32,
FheUint64, GpuIndex,
};
use rayon::iter::IndexedParallelIterator;
use rayon::prelude::{IntoParallelRefIterator, ParallelSlice};
Expand Down Expand Up @@ -902,7 +902,7 @@ mod test {
let idx: Vec<usize> = (0..sample_count).collect();
let pool = ThreadPoolBuilder::new()
.num_threads(8 * num_gpus)
.exit_handler(|_| unset_server_key())
.exit_handler(|_| clear_gpu_thread_locals())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a way to keep this as the default, or will we need to remember to add it every time we use it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we only want to have unset_server_key, we should remove the extra CudaStreamPool used by decrypt, which I think is OK to remove.

Otherwise, we should keep clear_gpu_thread_locals; it's only needed when using a non-local rayon pool.

.build()
.unwrap();
let real_values: Vec<u64> = pool.install(|| {
Expand Down
2 changes: 2 additions & 0 deletions tfhe/src/high_level_api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ use crate::{error, Error, Versionize};
use backward_compatibility::compressed_ciphertext_list::SquashedNoiseCiphertextStateVersions;
pub use config::{Config, ConfigBuilder};
#[cfg(feature = "gpu")]
pub use global_state::clear_gpu_thread_locals;
#[cfg(feature = "gpu")]
pub use global_state::CudaGpuChoice;
#[cfg(feature = "gpu")]
pub use global_state::CustomMultiGpuIndexes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,67 @@ use crate::core_crypto::gpu::get_number_of_gpus;
use crate::high_level_api::global_state::CustomMultiGpuIndexes;
use crate::prelude::*;
use crate::{
set_server_key, unset_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
FheUint32, GpuIndex,
clear_gpu_thread_locals, set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device,
FheUint32, FheUint8, GpuIndex,
};
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use rayon::ThreadPoolBuilder;

/// Regression test: dropping a rayon pool whose threads hold thread-local GPU
/// data (keys, streams, etc.) can cause issues if it is not properly cleaned up.
///
/// Rayon does not appear to wait for thread destruction, which creates
/// ordering problems with the CUDA driver.
///
/// The scenario is:
/// 1. Create a custom rayon thread pool.
/// 2. On each thread, set a GPU server key (which stores CUDA resources in thread-locals).
/// 3. Decrypt (decrypt initializes and uses a different set of CUDA stream thread-locals).
/// 4. Drop the pool.
#[test]
fn test_drop_rayon_pool_with_gpu_server_key_thread_locals() {
    let config = ConfigBuilder::default().build();
    let client_key = ClientKey::generate(config);

    let gpu_count = get_number_of_gpus() as usize;

    let compressed_key = CompressedServerKey::new(&client_key);
    let per_gpu_keys: Vec<_> = (0..gpu_count)
        .map(|gpu| compressed_key.decompress_to_specific_gpu(GpuIndex::new(gpu as u32)))
        .collect();

    let thread_count = 4 * gpu_count;
    let pool = ThreadPoolBuilder::new()
        .num_threads(thread_count)
        .exit_handler(|_| clear_gpu_thread_locals())
        .build()
        .unwrap();

    let decrypted: Vec<u8> = pool.install(|| {
        (0..thread_count)
            .into_par_iter()
            .map_init(
                || {
                    // Pin each rayon thread to one GPU's server key.
                    let gpu = rayon::current_thread_index().unwrap_or(0) % gpu_count;
                    set_server_key(per_gpu_keys[gpu].clone());
                },
                |(), _| {
                    let ct = FheUint8::encrypt_trivial(42u8);
                    let value: u8 = ct.decrypt(&client_key);
                    value
                },
            )
            .collect()
    });

    assert!(decrypted.iter().all(|&v| v == 42u8));

    // Explicitly drop the pool — this is where the bug manifests:
    // rayon threads are joined and their thread-locals (holding GPU server
    // keys referencing CUDA resources) are dropped.
    drop(pool);
}

#[test]
fn test_gpu_selection() {
Expand Down Expand Up @@ -187,6 +245,9 @@ fn test_specific_gpu_selection() {
assert_eq!(c.current_device(), Device::CudaGpu);
assert_eq!(c.gpu_indexes(), &[first_gpu]);
assert_eq!(decrypted, clear_a.wrapping_add(clear_b));
unset_server_key();
// unset_server_key is sufficient but we use clear_gpu_thread_locals
// in order to test that after calling it, the thread is still usable
// (the needed thread locals will lazily recreate themselves, nothing prevents them)
clear_gpu_thread_locals();
}
}
2 changes: 1 addition & 1 deletion tfhe/src/high_level_api/tests/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
mod cpk_re_randomization;
#[cfg(feature = "gpu")]
mod gpu_selection;
mod gpu;
mod noise_distribution;
mod noise_squashing;
mod tags_on_entities;
Expand Down
Loading