142 changes: 1 addition & 141 deletions src/common/datasource/src/buffered_writer.rs
@@ -12,28 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::future::Future;

use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::parquet::format::FileMetaData;
use snafu::{OptionExt, ResultExt};
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::error::{self, Result};
use crate::share_buffer::SharedBuffer;

pub struct LazyBufferedWriter<T, U, F> {
path: String,
writer_factory: F,
writer: Option<T>,
/// `None` means the [`LazyBufferedWriter`] has been closed.
encoder: Option<U>,
buffer: SharedBuffer,
rows_written: usize,
bytes_written: u64,
threshold: usize,
}
use crate::error::Result;

pub trait DfRecordBatchEncoder {
fn write(&mut self, batch: &RecordBatch) -> Result<()>;
@@ -43,126 +26,3 @@ pub trait DfRecordBatchEncoder {
pub trait ArrowWriterCloser {
async fn close(mut self) -> Result<FileMetaData>;
}

impl<
T: AsyncWrite + Send + Unpin,
U: DfRecordBatchEncoder + ArrowWriterCloser,
F: Fn(String) -> Fut,
Fut: Future<Output = Result<T>>,
> LazyBufferedWriter<T, U, F>
{
/// Closes the `LazyBufferedWriter`, flushing all buffered data to the underlying
/// storage if any rows have been written.
pub async fn close_with_arrow_writer(mut self) -> Result<(FileMetaData, u64)> {
let encoder = self
.encoder
.take()
.context(error::BufferedWriterClosedSnafu)?;
let metadata = encoder.close().await?;

// It's important to shut down! This flushes all pending writes.
self.close_inner_writer().await?;
Ok((metadata, self.bytes_written))
}
}

impl<
T: AsyncWrite + Send + Unpin,
U: DfRecordBatchEncoder,
F: Fn(String) -> Fut,
Fut: Future<Output = Result<T>>,
> LazyBufferedWriter<T, U, F>
{
/// Closes the writer and flushes the buffer data.
pub async fn close_inner_writer(&mut self) -> Result<()> {
// Use `rows_written` to keep track of whether any rows have been written.
// If no row's been written, then we can simply close the underlying
// writer without flush so that no file will be actually created.
if self.rows_written != 0 {
self.bytes_written += self.try_flush(true).await?;
}

if let Some(writer) = &mut self.writer {
writer.shutdown().await.context(error::AsyncWriteSnafu)?;
}
Ok(())
}

pub fn new(
threshold: usize,
buffer: SharedBuffer,
encoder: U,
path: impl AsRef<str>,
writer_factory: F,
) -> Self {
Self {
path: path.as_ref().to_string(),
threshold,
encoder: Some(encoder),
buffer,
rows_written: 0,
bytes_written: 0,
writer_factory,
writer: None,
}
}

pub async fn write(&mut self, batch: &RecordBatch) -> Result<()> {
let encoder = self
.encoder
.as_mut()
.context(error::BufferedWriterClosedSnafu)?;
encoder.write(batch)?;
self.rows_written += batch.num_rows();
self.bytes_written += self.try_flush(false).await?;
Ok(())
}

async fn try_flush(&mut self, all: bool) -> Result<u64> {
let mut bytes_written: u64 = 0;

// Once the buffered data size reaches the threshold, split the data into
// chunks (typically 4MB) and write them to the underlying storage.
while self.buffer.buffer.lock().unwrap().len() >= self.threshold {
let chunk = {
let mut buffer = self.buffer.buffer.lock().unwrap();
buffer.split_to(self.threshold)
};
let size = chunk.len();

self.maybe_init_writer()
.await?
.write_all(&chunk)
.await
.context(error::AsyncWriteSnafu)?;

bytes_written += size as u64;
}

if all {
bytes_written += self.try_flush_all().await?;
}
Ok(bytes_written)
}

/// Only initializes the underlying file writer once rows have been written.
async fn maybe_init_writer(&mut self) -> Result<&mut T> {
if let Some(ref mut writer) = self.writer {
Ok(writer)
} else {
let writer = (self.writer_factory)(self.path.clone()).await?;
Ok(self.writer.insert(writer))
}
}

async fn try_flush_all(&mut self) -> Result<u64> {
let remain = self.buffer.buffer.lock().unwrap().split();
let size = remain.len();
self.maybe_init_writer()
.await?
.write_all(&remain)
.await
.context(error::AsyncWriteSnafu)?;
Ok(size as u64)
}
}
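
After this change, buffered_writer.rs keeps only the two traits. For context, a minimal sketch of implementing the retained `DfRecordBatchEncoder` trait — the `CountingEncoder` type and the `common_datasource` crate path are assumptions for illustration, not part of this PR:

use arrow::record_batch::RecordBatch;
use common_datasource::buffered_writer::DfRecordBatchEncoder;
use common_datasource::error::Result;

/// Hypothetical encoder used only for illustration; it counts rows instead
/// of actually serializing them.
struct CountingEncoder {
    rows: usize,
}

impl DfRecordBatchEncoder for CountingEncoder {
    fn write(&mut self, batch: &RecordBatch) -> Result<()> {
        // A real encoder would serialize `batch` into a shared buffer here.
        self.rows += batch.num_rows();
        Ok(())
    }
}
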
202 changes: 202 additions & 0 deletions src/common/datasource/src/compressed_writer.rs
@@ -0,0 +1,202 @@
// Copyright 2023 Greptime Team
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_compression::tokio::write::{BzEncoder, GzipEncoder, XzEncoder, ZstdEncoder};
use snafu::ResultExt;
use tokio::io::{AsyncWrite, AsyncWriteExt};

use crate::compression::CompressionType;
use crate::error::{self, Result};

/// A compressed writer that wraps an underlying async writer with compression.
///
/// This writer supports multiple compression formats including GZIP, BZIP2, XZ, and ZSTD.
/// It provides transparent compression for any async writer implementation.
pub struct CompressedWriter {
inner: Box<dyn AsyncWrite + Unpin + Send>,
compression_type: CompressionType,
}

impl CompressedWriter {
/// Creates a new compressed writer with the specified compression type.
///
/// # Arguments
///
/// * `writer` - The underlying writer to wrap with compression
/// * `compression_type` - The type of compression to apply
pub fn new(
writer: impl AsyncWrite + Unpin + Send + 'static,
compression_type: CompressionType,
) -> Self {
let inner: Box<dyn AsyncWrite + Unpin + Send> = match compression_type {
CompressionType::Gzip => Box::new(GzipEncoder::new(writer)),
CompressionType::Bzip2 => Box::new(BzEncoder::new(writer)),
CompressionType::Xz => Box::new(XzEncoder::new(writer)),
CompressionType::Zstd => Box::new(ZstdEncoder::new(writer)),
CompressionType::Uncompressed => Box::new(writer),
};

Self {
inner,
compression_type,
}
}

/// Returns the compression type used by this writer.
pub fn compression_type(&self) -> CompressionType {
self.compression_type
}

/// Flushes the writer and shuts down the compression stream.
pub async fn shutdown(mut self) -> Result<()> {
self.inner
.shutdown()
.await
.context(error::AsyncWriteSnafu)?;
Ok(())
}
}

impl AsyncWrite for CompressedWriter {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}
}

/// A trait for converting async writers into compressed writers.
///
/// This trait is automatically implemented for all types that implement [`AsyncWrite`].
pub trait IntoCompressedWriter {
/// Converts this writer into a [`CompressedWriter`] with the specified compression type.
///
/// # Arguments
///
/// * `self` - The underlying writer to wrap with compression
/// * `compression_type` - The type of compression to apply
fn into_compressed_writer(self, compression_type: CompressionType) -> CompressedWriter
where
Self: AsyncWrite + Unpin + Send + 'static + Sized,
{
CompressedWriter::new(self, compression_type)
}
}

impl<W: AsyncWrite + Unpin + Send + 'static> IntoCompressedWriter for W {}

#[cfg(test)]
mod tests {
use tokio::io::{AsyncReadExt, AsyncWriteExt, duplex};

use super::*;

#[tokio::test]
async fn test_compressed_writer_gzip() {
let (duplex_writer, mut duplex_reader) = duplex(1024);
let mut writer = duplex_writer.into_compressed_writer(CompressionType::Gzip);
let original = b"test data for gzip compression";

writer.write_all(original).await.unwrap();
writer.shutdown().await.unwrap();

let mut buffer = Vec::new();
duplex_reader.read_to_end(&mut buffer).await.unwrap();

// The compressed data should be different from the original
assert_ne!(buffer, original);
assert!(!buffer.is_empty());
}

#[tokio::test]
async fn test_compressed_writer_bzip2() {
let (duplex_writer, mut duplex_reader) = duplex(1024);
let mut writer = duplex_writer.into_compressed_writer(CompressionType::Bzip2);
let original = b"test data for bzip2 compression";

writer.write_all(original).await.unwrap();
writer.shutdown().await.unwrap();

let mut buffer = Vec::new();
duplex_reader.read_to_end(&mut buffer).await.unwrap();

// The compressed data should be different from the original
assert_ne!(buffer, original);
assert!(!buffer.is_empty());
}

#[tokio::test]
async fn test_compressed_writer_xz() {
let (duplex_writer, mut duplex_reader) = duplex(1024);
let mut writer = duplex_writer.into_compressed_writer(CompressionType::Xz);
let original = b"test data for xz compression";

writer.write_all(original).await.unwrap();
writer.shutdown().await.unwrap();

let mut buffer = Vec::new();
duplex_reader.read_to_end(&mut buffer).await.unwrap();

// The compressed data should be different from the original
assert_ne!(buffer, original);
assert!(!buffer.is_empty());
}

#[tokio::test]
async fn test_compressed_writer_zstd() {
let (duplex_writer, mut duplex_reader) = duplex(1024);
let mut writer = duplex_writer.into_compressed_writer(CompressionType::Zstd);
let original = b"test data for zstd compression";

writer.write_all(original).await.unwrap();
writer.shutdown().await.unwrap();

let mut buffer = Vec::new();
duplex_reader.read_to_end(&mut buffer).await.unwrap();

// The compressed data should be different from the original
assert_ne!(buffer, original);
assert!(!buffer.is_empty());
}

#[tokio::test]
async fn test_compressed_writer_uncompressed() {
let (duplex_writer, mut duplex_reader) = duplex(1024);
let mut writer = duplex_writer.into_compressed_writer(CompressionType::Uncompressed);
let original = b"test data for uncompressed";

writer.write_all(original).await.unwrap();
writer.shutdown().await.unwrap();

let mut buffer = Vec::new();
duplex_reader.read_to_end(&mut buffer).await.unwrap();

// Uncompressed data should be the same as the original
assert_eq!(buffer, original);
}
}
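
Beyond the duplex-based tests above, a minimal usage sketch of the new writer against a file — assuming the crate is importable as `common_datasource` and that its error type implements `std::error::Error` (the output path is illustrative):

use common_datasource::compressed_writer::IntoCompressedWriter;
use common_datasource::compression::CompressionType;
use tokio::io::AsyncWriteExt;

async fn write_gzip() -> Result<(), Box<dyn std::error::Error>> {
    let file = tokio::fs::File::create("out.gz").await?;
    let mut writer = file.into_compressed_writer(CompressionType::Gzip);

    writer.write_all(b"hello, compressed world").await?;
    // `shutdown` consumes the writer and flushes the encoder's trailer;
    // skipping it can leave the compressed output truncated.
    writer.shutdown().await?;
    Ok(())
}
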
16 changes: 9 additions & 7 deletions src/common/datasource/src/error.rs
@@ -194,12 +194,6 @@ pub enum Error {
location: Location,
},

#[snafu(display("Buffered writer closed"))]
BufferedWriterClosed {
#[snafu(implicit)]
location: Location,
},

#[snafu(display("Failed to write parquet file, path: {}", path))]
WriteParquet {
path: String,
@@ -208,6 +202,14 @@
#[snafu(source)]
error: parquet::errors::ParquetError,
},

#[snafu(display("Failed to build file stream"))]
BuildFileStream {
#[snafu(implicit)]
location: Location,
#[snafu(source)]
error: datafusion::error::DataFusionError,
},
}

pub type Result<T> = std::result::Result<T, Error>;
@@ -239,7 +241,7 @@ impl ErrorExt for Error {
| ReadRecordBatch { .. }
| WriteRecordBatch { .. }
| EncodeRecordBatch { .. }
| BufferedWriterClosed { .. }
| BuildFileStream { .. }
| OrcReader { .. } => StatusCode::Unexpected,
}
}