diff --git a/Cargo.toml b/Cargo.toml index ee100f2aa..a9dc8bc1d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,7 @@ serde = {version = "^1", optional = true, features = ["derive"]} serde_bytes = {version = "0.11", optional = true} thiserror = {version = "^2" } url = "2.5" +cstr8 = "0.1.4" [features] bindgen = ["hts-sys/bindgen"] diff --git a/src/bcf/header.rs b/src/bcf/header.rs index 20594b09d..ed7ae4484 100644 --- a/src/bcf/header.rs +++ b/src/bcf/header.rs @@ -40,6 +40,7 @@ use std::str; use crate::htslib; +use cstr8::CStr8; use linear_map::LinearMap; use crate::errors::{Error, Result}; @@ -398,17 +399,14 @@ impl HeaderView { } /// Convert string ID (e.g., for a `FILTER` value) to its numeric identifier. - pub fn name_to_id(&self, id: &[u8]) -> Result { - let c_str = ffi::CString::new(id).unwrap(); + pub fn name_to_id(&self, id: &CStr8) -> Result { unsafe { match htslib::bcf_hdr_id2int( self.inner, htslib::BCF_DT_ID as i32, - c_str.as_ptr() as *const c_char, + id.as_ptr() as *const c_char, ) { - -1 => Err(Error::BcfUnknownID { - id: str::from_utf8(id).unwrap().to_owned(), - }), + -1 => Err(Error::BcfUnknownID { id: id.into() }), i => Ok(Id(i as u32)), } } diff --git a/src/bcf/mod.rs b/src/bcf/mod.rs index a19183073..cb83134b9 100644 --- a/src/bcf/mod.rs +++ b/src/bcf/mod.rs @@ -835,6 +835,7 @@ fn bcf_open(target: &[u8], mode: &[u8]) -> Result<*mut htslib::htsFile> { #[cfg(test)] mod tests { + use cstr8::cstr8; use tempfile::NamedTempFile; use super::record::Buffer; @@ -1222,8 +1223,8 @@ mod tests { use crate::bcf::header::Id; assert_eq!(header.id_to_name(Id(4)), b"GT"); - assert_eq!(header.name_to_id(b"GT").unwrap(), Id(4)); - assert!(header.name_to_id(b"XX").is_err()); + assert_eq!(header.name_to_id(cstr8!("GT")).unwrap(), Id(4)); + assert!(header.name_to_id(cstr8!("XX")).is_err()); } #[test] @@ -1427,10 +1428,12 @@ mod tests { record.set_qual(10.0); - record.push_info_integer(b"N1", &[32]).unwrap(); - record.push_info_float(b"F1", &[33.0]).unwrap(); - record.push_info_string(b"S1", &[b"fourtytwo"]).unwrap(); - record.push_info_flag(b"X1").unwrap(); + record.push_info_integer(cstr8!("N1"), &[32]).unwrap(); + record.push_info_float(cstr8!("F1"), &[33.0]).unwrap(); + record + .push_info_string(cstr8!("S1"), &[b"fourtytwo"]) + .unwrap(); + record.push_info_flag(cstr8!("X1")).unwrap(); record .push_genotypes(&[ @@ -1442,12 +1445,16 @@ mod tests { .unwrap(); record - .push_format_string(b"FS1", &[&b"yes"[..], &b"no"[..]]) + .push_format_string(cstr8!("FS1"), &[&b"yes"[..], &b"no"[..]]) + .unwrap(); + record + .push_format_integer(cstr8!("FF1"), &[43, 11]) + .unwrap(); + record + .push_format_float(cstr8!("FN1"), &[42.0, 10.0]) .unwrap(); - record.push_format_integer(b"FF1", &[43, 11]).unwrap(); - record.push_format_float(b"FN1", &[42.0, 10.0]).unwrap(); record - .push_format_char(b"CH1", &[b"A"[0], b"B"[0]]) + .push_format_char(cstr8!("CH1"), &[b"A"[0], b"B"[0]]) .unwrap(); // Finally, write out the record. diff --git a/src/bcf/record.rs b/src/bcf/record.rs index 8cd1edd4b..c82b70e58 100644 --- a/src/bcf/record.rs +++ b/src/bcf/record.rs @@ -15,6 +15,7 @@ use std::str; use std::{ffi, iter}; use bio_types::genome; +use cstr8::{cstr8, CStr8, CString8}; use derive_new::new; use ieee754::Ieee754; use lazy_static::lazy_static; @@ -86,13 +87,25 @@ pub trait FilterId { impl FilterId for [u8] { fn id_from_header(&self, header: &HeaderView) -> Result { - header.name_to_id(self) + let str = String::from_utf8(self.to_vec()).map_err(|_| Error::BcfInvalidRecord)?; + let id = CString8::new(str).map_err(|_| Error::BcfInvalidRecord)?; + header.name_to_id(&id) } fn is_pass(&self) -> bool { matches!(self, b"PASS" | b".") } } +impl<'a> FilterId for &'a CStr8 { + fn id_from_header(&self, header: &HeaderView) -> Result { + header.name_to_id(self) + } + + fn is_pass(&self) -> bool { + matches!(self.as_bytes(), b"PASS" | b".") + } +} + impl FilterId for Id { fn id_from_header(&self, _header: &HeaderView) -> Result { Ok(*self) @@ -421,6 +434,7 @@ impl Record { /// /// # Example /// ```rust + /// # use cstr8::cstr8; /// # use rust_htslib::bcf::{Format, Header, Writer}; /// # use rust_htslib::bcf::header::Id; /// # use tempfile::NamedTempFile; @@ -431,8 +445,8 @@ impl Record { /// header.push_record(br#"##FILTER="#); /// # let vcf = Writer::from_path(path, &header, true, Format::Vcf).unwrap(); /// # let mut record = vcf.empty_record(); - /// let foo = record.header().name_to_id(b"foo").unwrap(); - /// let bar = record.header().name_to_id(b"bar").unwrap(); + /// let foo = record.header().name_to_id(cstr8!("foo")).unwrap(); + /// let bar = record.header().name_to_id(cstr8!("bar")).unwrap(); /// assert!(record.has_filter("PASS".as_bytes())); /// let mut filters = vec![&foo, &bar]; /// record.set_filters(&filters).unwrap(); @@ -473,6 +487,7 @@ impl Record { /// /// # Example /// ```rust + /// # use cstr8::cstr8; /// # use rust_htslib::bcf::{Format, Header, Writer}; /// # use tempfile::NamedTempFile; /// # let tmp = tempfile::NamedTempFile::new().unwrap(); @@ -483,7 +498,7 @@ impl Record { /// # let vcf = Writer::from_path(path, &header, true, Format::Vcf).unwrap(); /// # let mut record = vcf.empty_record(); /// let foo = "foo".as_bytes(); - /// let bar = record.header().name_to_id(b"bar").unwrap(); + /// let bar = record.header().name_to_id(cstr8!("bar")).unwrap(); /// assert!(record.has_filter("PASS".as_bytes())); /// /// record.push_filter(foo).unwrap(); @@ -689,7 +704,7 @@ impl Record { /// ``` pub fn push_genotypes(&mut self, genotypes: &[GenotypeAllele]) -> Result<()> { let encoded: Vec = genotypes.iter().map(|gt| i32::from(*gt)).collect(); - self.push_format_integer(b"GT", &encoded) + self.push_format_integer(cstr8!("GT"), &encoded) } /// Add/replace genotypes in FORMAT GT tag by providing a list of genotypes. @@ -759,7 +774,7 @@ impl Record { )), ); } - self.push_format_integer(b"GT", &data) + self.push_format_integer(cstr8!("GT"), &data) } /// Get genotypes as vector of one `Genotype` per sample. @@ -798,6 +813,7 @@ impl Record { /// for an example of the setup used here.* /// /// ```rust + /// # use cstr8::cstr8; /// # use rust_htslib::bcf::{Format, Writer}; /// # use rust_htslib::bcf::header::Header; /// # @@ -808,7 +824,7 @@ impl Record { /// # // Write uncompressed VCF to stdout with above header and get an empty record /// # let mut vcf = Writer::from_stdout(&header, true, Format::Vcf).unwrap(); /// # let mut record = vcf.empty_record(); - /// record.push_format_integer(b"DP", &[20, 12]).expect("Failed to set DP format field"); + /// record.push_format_integer(cstr8!("DP"), &[20, 12]).expect("Failed to set DP format field"); /// /// let read_depths = record.format(b"DP").integer().expect("Couldn't retrieve DP field"); /// let sample1_depth = read_depths[0]; @@ -846,7 +862,7 @@ impl Record { /// # Errors /// /// Returns error if tag is not present in header. - pub fn push_format_integer(&mut self, tag: &[u8], data: &[i32]) -> Result<()> { + pub fn push_format_integer(&mut self, tag: &CStr8, data: &[i32]) -> Result<()> { self.push_format(tag, data, htslib::BCF_HT_INT) } @@ -869,6 +885,7 @@ impl Record { /// VCF, header, and record. /// /// ``` + /// # use cstr8::cstr8; /// # use rust_htslib::bcf::{Format, Writer}; /// # use rust_htslib::bcf::header::Header; /// # use rust_htslib::bcf::record::GenotypeAllele; @@ -880,10 +897,10 @@ impl Record { /// # header.push_sample("test_sample".as_bytes()); /// # let mut vcf = Writer::from_stdout(&header, true, Format::Vcf).unwrap(); /// # let mut record = vcf.empty_record(); - /// record.push_format_float(b"AF", &[0.5]); + /// record.push_format_float(cstr8!("AF"), &[0.5]); /// assert_eq!(0.5, record.format(b"AF").float().unwrap()[0][0]); /// ``` - pub fn push_format_float(&mut self, tag: &[u8], data: &[f32]) -> Result<()> { + pub fn push_format_float(&mut self, tag: &CStr8, data: &[f32]) -> Result<()> { self.push_format(tag, data, htslib::BCF_HT_REAL) } @@ -898,19 +915,18 @@ impl Record { /// # Errors /// /// Returns error if tag is not present in header. - pub fn push_format_char(&mut self, tag: &[u8], data: &[u8]) -> Result<()> { + pub fn push_format_char(&mut self, tag: &CStr8, data: &[u8]) -> Result<()> { self.push_format(tag, data, htslib::BCF_HT_STR) } /// Add a format tag. Data is a flattened two-dimensional array. /// The first dimension contains one array for each sample. - fn push_format(&mut self, tag: &[u8], data: &[T], ht: u32) -> Result<()> { - let tag_c_str = ffi::CString::new(tag).unwrap(); + fn push_format(&mut self, tag: &CStr8, data: &[T], ht: u32) -> Result<()> { unsafe { if htslib::bcf_update_format( self.header().inner, self.inner, - tag_c_str.as_ptr() as *mut c_char, + tag.as_ptr() as *mut c_char, data.as_ptr() as *const ::std::os::raw::c_void, data.len() as i32, ht as i32, @@ -918,9 +934,7 @@ impl Record { { Ok(()) } else { - Err(Error::BcfSetTag { - tag: str::from_utf8(tag).unwrap().to_owned(), - }) + Err(Error::BcfSetTag { tag: tag.into() }) } } } @@ -939,7 +953,7 @@ impl Record { /// # Errors /// /// Returns error if tag is not present in header. - pub fn push_format_string>(&mut self, tag: &[u8], data: &[D]) -> Result<()> { + pub fn push_format_string>(&mut self, tag: &CStr8, data: &[D]) -> Result<()> { assert!( !data.is_empty(), "given string data must have at least 1 element" @@ -952,42 +966,39 @@ impl Record { .iter() .map(|s| s.as_ptr() as *mut i8) .collect::>(); - let tag_c_str = ffi::CString::new(tag).unwrap(); unsafe { if htslib::bcf_update_format_string( self.header().inner, self.inner, - tag_c_str.as_ptr() as *mut c_char, + tag.as_ptr() as *mut c_char, c_ptrs.as_slice().as_ptr() as *mut *const c_char, data.len() as i32, ) == 0 { Ok(()) } else { - Err(Error::BcfSetTag { - tag: str::from_utf8(tag).unwrap().to_owned(), - }) + Err(Error::BcfSetTag { tag: tag.into() }) } } } /// Add/replace an integer-typed INFO entry. - pub fn push_info_integer(&mut self, tag: &[u8], data: &[i32]) -> Result<()> { + pub fn push_info_integer(&mut self, tag: &CStr8, data: &[i32]) -> Result<()> { self.push_info(tag, data, htslib::BCF_HT_INT) } /// Remove the integer-typed INFO entry. - pub fn clear_info_integer(&mut self, tag: &[u8]) -> Result<()> { + pub fn clear_info_integer(&mut self, tag: &CStr8) -> Result<()> { self.push_info::(tag, &[], htslib::BCF_HT_INT) } /// Add/replace a float-typed INFO entry. - pub fn push_info_float(&mut self, tag: &[u8], data: &[f32]) -> Result<()> { + pub fn push_info_float(&mut self, tag: &CStr8, data: &[f32]) -> Result<()> { self.push_info(tag, data, htslib::BCF_HT_REAL) } /// Remove the float-typed INFO entry. - pub fn clear_info_float(&mut self, tag: &[u8]) -> Result<()> { + pub fn clear_info_float(&mut self, tag: &CStr8) -> Result<()> { self.push_info::(tag, &[], htslib::BCF_HT_REAL) } @@ -997,13 +1008,12 @@ impl Record { /// * `tag` - the tag to add/replace /// * `data` - the data to set /// * `ht` - the HTSLib type to use - fn push_info(&mut self, tag: &[u8], data: &[T], ht: u32) -> Result<()> { - let tag_c_str = ffi::CString::new(tag).unwrap(); + fn push_info(&mut self, tag: &CStr8, data: &[T], ht: u32) -> Result<()> { unsafe { if htslib::bcf_update_info( self.header().inner, self.inner, - tag_c_str.as_ptr() as *mut c_char, + tag.as_ptr() as *mut c_char, data.as_ptr() as *const ::std::os::raw::c_void, data.len() as i32, ht as i32, @@ -1011,36 +1021,77 @@ impl Record { { Ok(()) } else { - Err(Error::BcfSetTag { - tag: str::from_utf8(tag).unwrap().to_owned(), - }) + Err(Error::BcfSetTag { tag: tag.into() }) } } } /// Set flag into the INFO column. - pub fn push_info_flag(&mut self, tag: &[u8]) -> Result<()> { + pub fn push_info_flag(&mut self, tag: &CStr8) -> Result<()> { self.push_info_string_impl(tag, &[b""], htslib::BCF_HT_FLAG) } /// Remove the flag from the INFO column. - pub fn clear_info_flag(&mut self, tag: &[u8]) -> Result<()> { + pub fn clear_info_flag(&mut self, tag: &CStr8) -> Result<()> { self.push_info_string_impl(tag, &[], htslib::BCF_HT_FLAG) } /// Add/replace a string-typed INFO entry. - pub fn push_info_string(&mut self, tag: &[u8], data: &[&[u8]]) -> Result<()> { + pub fn push_info_string(&mut self, tag: &CStr8, data: &[&[u8]]) -> Result<()> { self.push_info_string_impl(tag, data, htslib::BCF_HT_STR) } /// Remove the string field from the INFO column. - pub fn clear_info_string(&mut self, tag: &[u8]) -> Result<()> { + pub fn clear_info_string(&mut self, tag: &CStr8) -> Result<()> { self.push_info_string_impl(tag, &[], htslib::BCF_HT_STR) } /// Add an string-valued INFO tag. - fn push_info_string_impl(&mut self, tag: &[u8], data: &[&[u8]], ht: u32) -> Result<()> { - let mut buf: Vec = Vec::new(); + fn push_info_string_impl(&mut self, tag: &CStr8, data: &[&[u8]], ht: u32) -> Result<()> { + if data.is_empty() { + // Clear the tag + let c_str = unsafe { CStr8::from_utf8_with_nul_unchecked(b"\0") }; + let len = 0; + unsafe { + return if htslib::bcf_update_info( + self.header().inner, + self.inner, + tag.as_ptr() as *mut c_char, + c_str.as_ptr() as *const ::std::os::raw::c_void, + len as i32, + ht as i32, + ) == 0 + { + Ok(()) + } else { + Err(Error::BcfSetTag { tag: tag.into() }) + }; + } + } + + if data == &[b""] { + // This is a flag + let c_str = unsafe { CStr8::from_utf8_with_nul_unchecked(b"\0") }; + let len = 1; + unsafe { + return if htslib::bcf_update_info( + self.header().inner, + self.inner, + tag.as_ptr() as *mut c_char, + c_str.as_ptr() as *const ::std::os::raw::c_void, + len as i32, + ht as i32, + ) == 0 + { + Ok(()) + } else { + Err(Error::BcfSetTag { tag: tag.into() }) + }; + } + } + + let data_bytes = data.iter().map(|x| x.len() + 2).sum(); // estimate for buffer pre-alloc + let mut buf: Vec = Vec::with_capacity(data_bytes); for (i, &s) in data.iter().enumerate() { if i > 0 { buf.extend(b","); @@ -1053,12 +1104,11 @@ impl Record { } else { c_str.to_bytes().len() }; - let tag_c_str = ffi::CString::new(tag).unwrap(); unsafe { if htslib::bcf_update_info( self.header().inner, self.inner, - tag_c_str.as_ptr() as *mut c_char, + tag.as_ptr() as *mut c_char, c_str.as_ptr() as *const ::std::os::raw::c_void, len as i32, ht as i32, @@ -1066,9 +1116,7 @@ impl Record { { Ok(()) } else { - Err(Error::BcfSetTag { - tag: str::from_utf8(tag).unwrap().to_owned(), - }) + Err(Error::BcfSetTag { tag: tag.into() }) } } } @@ -1769,7 +1817,7 @@ mod tests { let mut record = vcf.empty_record(); assert!(record.has_filter("PASS".as_bytes())); record.push_filter("foo".as_bytes()).unwrap(); - let bar = record.header().name_to_id(b"bar").unwrap(); + let bar = record.header().name_to_id(cstr8!("bar")).unwrap(); record.push_filter(&bar).unwrap(); assert!(record.has_filter("foo".as_bytes())); assert!(record.has_filter(&bar)); @@ -1811,8 +1859,8 @@ mod tests { header.push_record(br#"##FILTER="#); let vcf = Writer::from_path(path, &header, true, Format::Vcf).unwrap(); let mut record = vcf.empty_record(); - let foo = record.header().name_to_id(b"foo").unwrap(); - let bar = record.header().name_to_id(b"bar").unwrap(); + let foo = record.header().name_to_id(cstr8!("foo")).unwrap(); + let bar = record.header().name_to_id(cstr8!("bar")).unwrap(); record.set_filters(&[&foo, &bar]).unwrap(); assert!(record.has_filter(&foo)); assert!(record.has_filter(&bar));