Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/cublas/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
//! Wrappers around the [cublas API](https://docs.nvidia.com/cuda/cublas/index.html),
//! in three levels. See crate documentation for description of each.
//! [CudaBlas] wraps around the [cublas API](https://docs.nvidia.com/cuda/cublas/index.html).
//!
//! To use:
//!
//! 1. Instantiate a [CudaBlas] handle with [CudaBlas::new()]
//! 2. Choose your operation: [Gemm], [Gemv], and [Asum] traits, which [CudaBlas] implements.
//! 3. f16/bf16/f32/f64 are all supported at the trait level.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This statement is incorrect. While the Gemm trait has implementations for f16, bf16, f32, and f64, the Gemv and Asum traits only support f32 and f64. This documentation is misleading to users of the library.

Please update the documentation to accurately reflect the supported types for each trait.

//! 3. Supported types: [Gemm] supports `f16`/`bf16`/`f32`/`f64`; [Gemv] & [Asum] support `f32`/`f64`.

//! 4. Instantiate your corresponding config: [GemmConfig], [StridedBatchedConfig], [GemvConfig], [AsumConfig]
//! 5. Call using [CudaBlas::gemm()], [CudaBlas::gemv()], or [CudaBlas::asum()]

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This is slightly misleading. The methods like gemm are part of the Gemm trait, not inherent methods of CudaBlas. As a result, the rustdoc link [CudaBlas::gemm()] will be broken.

It would be clearer to show an example of a method call on an instance, e.g., blas.gemm(...), and to hint that the trait needs to be in scope for the method to be available.

//! 5. Call the trait's method on the handle, e.g. `blas.gemm(...)`.

//!
//! Note that all above apis work with [crate::driver::DevicePtr]/[crate::driver::DevicePtrMut], so they
//! accept [crate::driver::CudaSlice], [crate::driver::CudaView], and [crate::driver::CudaViewMut].

pub mod result;
pub mod safe;
Expand Down
8 changes: 8 additions & 0 deletions src/cublaslt/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
//! [CudaBlasLT] wraps around [cuBLASLt](https://docs.nvidia.com/cuda/cublas/index.html#using-the-cublaslt-api) via:
//!
//! 1. Instantiate a [CudaBlasLT] handle with [CudaBlasLT::new()]
//! 2. Execute a gemm using [CudaBlasLT::matmul()]
//!
//! Note that all above apis work with [crate::driver::DevicePtr]/[crate::driver::DevicePtrMut], so they
//! accept [crate::driver::CudaSlice], [crate::driver::CudaView], and [crate::driver::CudaViewMut].

pub mod result;
pub mod safe;
#[allow(warnings)]
Expand Down
9 changes: 9 additions & 0 deletions src/cufile/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,12 @@
//! [Cufile] wraps around [cuFILE](https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html) via:
//!
//! 1. Instantiate a new handle to the api with [Cufile::new()]
//! 2. Register a file with [Cufile::register()], this accepts a [std::fs::File].
//! 3. Read/write from filesystem using [FileHandle::sync_read], [FileHandle::sync_write], [crate::driver::CudaStream::memcpy_dtof()], [crate::driver::CudaStream::memcpy_ftod()].
//!
//! Note that all safe apis work with [crate::driver::DevicePtr] and [crate::driver::DevicePtrMut], meaning they accept both
//! [crate::driver::CudaSlice] and [crate::driver::CudaView]/[crate::driver::CudaViewMut].

pub mod result;
pub mod safe;
#[allow(warnings)]
Expand Down
9 changes: 7 additions & 2 deletions src/curand/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
//! Wrappers around the [cuRAND API](https://docs.nvidia.com/cuda/curand/index.html)
//! in three levels. See crate documentation for description of each.
//! [CudaRng] safe bindings around [cuRAND](https://docs.nvidia.com/cuda/curand/index.html).
//!
//! Instantiate with [CudaRng::new()], and then fill existing [crate::driver::CudaSlice]/[crate::driver::CudaViewMut]
//! with three different
//! 1. Uniform - [CudaRng::fill_with_uniform()]
//! 2. Normal - [CudaRng::fill_with_normal()]
//! 3. LogNormal - [CudaRng::fill_with_log_normal()]

pub mod result;
pub mod safe;
Expand Down
12 changes: 10 additions & 2 deletions src/nccl/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
//! Wrappers around the [NCCL API](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html)
//! in three levels. See crate documentation for description of each.
//! [Comm] wraps around the [NCCL API](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/index.html), via:
//!
//! 1. Instantiate with [Comm::from_devices()] or [Comm::from_rank()]
//! 2. Peer to peer with [Comm::send()]/[Comm::recv()]
//! 3. Broadcast [Comm::broadcast()]/[Comm::broadcast_in_place()]
//! 4. Reduce: [Comm::reduce()]/[Comm::reduce_in_place()]
//! 5. Gather & Reduce [Comm::all_gather()]/[Comm::all_reduce()]/[Comm::all_reduce_in_place()]
//!
//! Note that all above apis work with [crate::driver::DevicePtr]/[crate::driver::DevicePtrMut], so they
//! accept [crate::driver::CudaSlice], [crate::driver::CudaView], and [crate::driver::CudaViewMut].

pub mod result;
pub mod safe;
Expand Down
2 changes: 2 additions & 0 deletions src/nccl/safe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ impl Comm {
}

impl Comm {
/// Send data to one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclsend)
pub fn send<S: DevicePtr<T>, T: NcclType>(
&self,
data: &S,
Expand All @@ -229,6 +230,7 @@ impl Comm {
Ok(())
}

/// Receive data from one peer, see [cuda docs](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/p2p.html#ncclrecv)
pub fn recv<R: DevicePtrMut<T>, T: NcclType>(
&self,
buff: &mut R,
Expand Down
Loading