diff --git a/rustfst/src/algorithms/compose/add_on.rs b/rustfst/src/algorithms/compose/add_on.rs index c02ab05e2..9d72409e0 100644 --- a/rustfst/src/algorithms/compose/add_on.rs +++ b/rustfst/src/algorithms/compose/add_on.rs @@ -1,24 +1,38 @@ use std::fmt::Debug; +use std::io::Write; +use std::marker::PhantomData; use std::sync::Arc; use anyhow::Result; +use nom::combinator::verify; +use nom::IResult; +use crate::{NomCustomError, StateId, SymbolTable, Tr}; use crate::fst_properties::FstProperties; -use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, StateIterator}; +use crate::fst_properties::properties::EXPANDED; +use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, SerializableFst, StateIterator}; +use crate::parsers::{parse_bin_bool, parse_bin_i32, write_bin_bool, write_bin_i32}; +use crate::parsers::bin_fst::fst_header::{FST_MAGIC_NUMBER, FstFlags, FstHeader, OpenFstString}; +use crate::prelude::{SerializableSemiring, SerializeBinary}; use crate::semirings::Semiring; -use crate::{StateId, SymbolTable}; /// Adds an object of type T to an FST. /// The resulting type is a new FST implementation. #[derive(Debug, PartialEq, Clone)] -pub struct FstAddOn { +pub struct FstAddOn + where + W: Semiring, + F: Fst +{ pub(crate) fst: F, pub(crate) add_on: T, + w: PhantomData, + fst_type: String } -impl FstAddOn { - pub fn new(fst: F, add_on: T) -> Self { - Self { fst, add_on } +impl, T> FstAddOn { + pub fn new(fst: F, add_on: T, fst_type: String) -> Self { + Self { fst, add_on, w: PhantomData, fst_type } } pub fn fst(&self) -> &F { @@ -34,7 +48,7 @@ impl FstAddOn { } } -impl, T> CoreFst for FstAddOn { +impl, T> CoreFst for FstAddOn { type TRS = F::TRS; fn start(&self) -> Option { @@ -78,7 +92,7 @@ impl, T> CoreFst for FstAddOn { } } -impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn { +impl<'a, W: Semiring, F: Fst, T> StateIterator<'a> for FstAddOn { type Iter = >::Iter; fn states_iter(&'a self) -> Self::Iter { @@ -86,19 +100,19 @@ impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn { } } -impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn +impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn where W: Semiring + 'a, - F: FstIterator<'a, W>, + F: Fst, { - type FstIter = F::FstIter; + type FstIter = >::FstIter; fn fst_iter(&'a self) -> Self::FstIter { self.fst.fst_iter() } } -impl Fst for FstAddOn +impl Fst for FstAddOn where W: Semiring, F: Fst, @@ -128,7 +142,7 @@ where } } -impl ExpandedFst for FstAddOn +impl ExpandedFst for FstAddOn where W: Semiring, F: ExpandedFst, @@ -139,16 +153,95 @@ where } } -impl FstIntoIterator for FstAddOn +impl FstIntoIterator for FstAddOn where W: Semiring, - F: FstIntoIterator, + F: FstIntoIterator + Fst , T: Debug, { type TrsIter = F::TrsIter; - type FstIter = F::FstIter; + type FstIter = >::FstIter; fn fst_into_iter(self) -> Self::FstIter { self.fst.fst_into_iter() } } + +static ADD_ON_MAGIC_NUMBER: i32 = 446681434; +static ADD_ON_MIN_FILE_VERSION: i32 = 1; +static ADD_ON_FILE_VERSION: i32 = 1; + +impl SerializeBinary for FstAddOn>, Option>)> +where + W: SerializableSemiring, + F: SerializableFst, + AO1: SerializeBinary + Debug + Clone + PartialEq, + AO2: SerializeBinary + Debug + Clone + PartialEq, +{ + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + + let (i, hdr) = FstHeader::parse( + i, + ADD_ON_MIN_FILE_VERSION, + Option::<&str>::None, + Tr::::tr_type(), + )?; + + let (i, _) = verify(parse_bin_i32, |v: &i32| *v == ADD_ON_MAGIC_NUMBER)(i)?; + let (i, fst) = F::parse_binary(i)?; + + let (i, _have_addon) = verify(parse_bin_bool, |v| *v)(i)?; + + let (i, have_addon1) = parse_bin_bool(i)?; + let (i, add_on_1) = if have_addon1 { + let (s, a) = AO1::parse_binary(i)?; + (s, Some(a)) + } else { + (i, None) + }; + let (i, have_addon2) = parse_bin_bool(i)?; + let (i, add_on_2) = if have_addon2 { + let (s, a) = AO2::parse_binary(i)?; + (s, Some(a)) + } else { + (i, None) + }; + + let add_on = (add_on_1.map(Arc::new), add_on_2.map(Arc::new)); + let fst_add_on = FstAddOn::new(fst, add_on, hdr.fst_type.s().clone()); + Ok((i, fst_add_on)) + } + + fn write_binary(&self, writer: &mut WB) -> Result<()> { + let hdr = FstHeader { + magic_number: FST_MAGIC_NUMBER, + fst_type: OpenFstString::new(&self.fst_type), + tr_type: OpenFstString::new(Tr::::tr_type()), + version: ADD_ON_FILE_VERSION, + flags: FstFlags::empty(), + properties: self.properties().bits() | EXPANDED, + start: -1, + num_states: 0, + num_trs: 0, + isymt: None, + osymt: None, + }; + hdr.write(writer)?; + write_bin_i32(writer, ADD_ON_MAGIC_NUMBER)?; + self.fst.write_binary(writer)?; + write_bin_bool(writer, true)?; + if let Some(add_on) = self.add_on.0.as_ref() { + write_bin_bool(writer, true)?; + add_on.write_binary(writer)?; + } else { + write_bin_bool(writer, false)?; + } + if let Some(add_on) = self.add_on.1.as_ref() { + write_bin_bool(writer, true)?; + add_on.write_binary(writer)?; + } else { + write_bin_bool(writer, false)?; + } + Ok(()) + } +} diff --git a/rustfst/src/algorithms/compose/interval_set.rs b/rustfst/src/algorithms/compose/interval_set.rs index d871f4005..bbe29c36e 100644 --- a/rustfst/src/algorithms/compose/interval_set.rs +++ b/rustfst/src/algorithms/compose/interval_set.rs @@ -1,10 +1,16 @@ use serde::{Deserialize, Serialize}; use std::cmp::Ordering; use std::collections::HashSet; +use std::io::Write; use std::slice::Iter as IterSlice; use std::vec::IntoIter as IntoIterVec; +use nom::IResult; +use nom::multi::count; use superslice::Ext; use unsafe_unwrap::UnsafeUnwrap; +use crate::NomCustomError; +use crate::parsers::{parse_bin_i32, parse_bin_i64, write_bin_i32, write_bin_i64}; +use crate::prelude::SerializeBinary; /// Half-open integral interval [a, b) of signed integers of type T. #[derive(PartialEq, Clone, Eq, Debug, Serialize, Deserialize)] @@ -47,6 +53,26 @@ impl Ord for IntInterval { } } +impl SerializeBinary for IntInterval { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let (i, begin) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?; + let (i, end) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?; + Ok(( + i, + IntInterval { + begin, + end + }, + )) + } + + fn write_binary(&self, writer: &mut WB) -> anyhow::Result<()> { + write_bin_i32(writer, self.begin as i32)?; + write_bin_i32(writer, self.end as i32)?; + Ok(()) + } +} + /// Stores IntIntervals in a vector. In addition, keeps the count of points in /// all intervals. #[derive(Clone, PartialOrd, PartialEq, Debug)] @@ -95,11 +121,52 @@ impl VectorIntervalStore { } } +impl SerializeBinary for VectorIntervalStore { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let (i, interval_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?; + let (i, intervals) = count(IntInterval::parse_binary, interval_count)(i)?; + let (i, store_count) = parse_bin_i32(i)?; + Ok(( + i, + VectorIntervalStore { + intervals, + count: Some(store_count as usize) + }, + )) + } + + fn write_binary(&self, writer: &mut WB) -> anyhow::Result<()> { + write_bin_i64(writer, self.intervals.len() as i64)?; + for interval in self.intervals.iter() { + interval.write_binary(writer)?; + } + write_bin_i32(writer, self.count.unwrap_or_default() as i32)?; + Ok(()) + } +} + #[derive(PartialOrd, PartialEq, Default, Clone, Debug)] pub struct IntervalSet { pub(crate) intervals: VectorIntervalStore, } +impl SerializeBinary for IntervalSet { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let (i, intervals) = VectorIntervalStore::parse_binary(i)?; + Ok(( + i, + IntervalSet { + intervals + }, + )) + } + + fn write_binary(&self, writer: &mut WB) -> anyhow::Result<()> { + self.intervals.write_binary(writer)?; + Ok(()) + } +} + impl IntervalSet { pub fn len(&self) -> usize { self.intervals.len() @@ -149,7 +216,7 @@ impl IntervalSet { elt.begin + 1 == elt.end } - // Sorts, collapses overlapping and adjacent interals, and sets count. + // Sorts, collapses overlapping and adjacent intervals, and sets count. pub fn normalize(&mut self) { let intervals = &mut self.intervals.intervals; intervals.sort(); diff --git a/rustfst/src/algorithms/compose/label_reachable.rs b/rustfst/src/algorithms/compose/label_reachable.rs index 12698419a..701970426 100644 --- a/rustfst/src/algorithms/compose/label_reachable.rs +++ b/rustfst/src/algorithms/compose/label_reachable.rs @@ -1,8 +1,11 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::io::Write; use std::sync::Arc; use anyhow::Result; +use nom::IResult; +use nom::multi::count; use crate::algorithms::compose::{IntervalSet, StateReachable}; use crate::algorithms::tr_compares::{ILabelCompare, OLabelCompare}; @@ -11,7 +14,8 @@ use crate::fst_impls::VectorFst; use crate::fst_properties::FstProperties; use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst}; use crate::semirings::Semiring; -use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED}; +use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED, NomCustomError}; +use crate::parsers::{parse_bin_bool, parse_bin_i64, parse_bin_i32, parse_bin_u32, SerializeBinary, write_bin_bool, write_bin_i64, write_bin_u32, write_bin_i32}; #[derive(Debug, Clone, PartialEq)] pub struct LabelReachableData { @@ -116,6 +120,73 @@ impl LabelReachableData { } } +fn parse_label_map(i: &[u8]) -> IResult<&[u8], HashMap, NomCustomError<&[u8]>> { + let mut stream = i; + let r = parse_bin_i64(stream).map(|(s, v)| (s, v as usize))?; + stream = r.0; + let map_size = r.1; + let mut map = HashMap::with_capacity(map_size); + for _ in 0..map_size { + let r = parse_bin_i32(stream).map(|(s, v)| (s, v as Label))?; + let key = r.1; + let r = parse_bin_i32(r.0).map(|(s, v)| (s, v as Label))?; + stream = r.0; + let val = r.1; + map.insert(key, val); + } + Ok((stream, map)) +} + +fn write_label_map(writer: &mut WB, map: &HashMap) -> Result<()> { + write_bin_i64(writer, map.len() as i64)?; + for (k, v) in map.iter() { + write_bin_i32(writer, *k as i32)?; + write_bin_i32(writer, *v as i32)?; + } + Ok(()) +} + +impl SerializeBinary for LabelReachableData { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let (i, reach_input) = parse_bin_bool(i)?; + let (i, have_relabel_data) = parse_bin_bool(i)?; + let (i, label2index) = if have_relabel_data { + parse_label_map(i)? + } else { + (i, Default::default()) + }; + let (i, final_label) = parse_bin_u32(i).map(|(s, v)| (s, v as Label))?; + let (i, set_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?; + let (i, interval_sets) = count(IntervalSet::parse_binary, set_count)(i)?; + Ok(( + i, + LabelReachableData { + reach_input, + final_label, + label2index, + interval_sets + } + )) + } + + fn write_binary(&self, writer: &mut WB) -> Result<()> { + write_bin_bool(writer, self.reach_input)?; + // OpenFst checks keep_relabel_data here which is missing in this struct. + // Instead we check if we have any data in label2index; + let have_relabel_data = !self.label2index.is_empty(); + write_bin_bool(writer, have_relabel_data)?; + if have_relabel_data { + write_label_map(writer, &self.label2index)?; + } + write_bin_u32(writer, self.final_label as u32)?; + write_bin_i64(writer, self.interval_sets.len() as i64)?; + for interval_set in self.interval_sets.iter() { + interval_set.write_binary(writer)?; + } + Ok(()) + } +} + #[derive(Debug, Clone, PartialEq)] pub struct LabelReachable { data: Arc, diff --git a/rustfst/src/algorithms/compose/matcher_fst.rs b/rustfst/src/algorithms/compose/matcher_fst.rs index d3583eebf..f2f58d01d 100644 --- a/rustfst/src/algorithms/compose/matcher_fst.rs +++ b/rustfst/src/algorithms/compose/matcher_fst.rs @@ -17,13 +17,13 @@ use crate::semirings::Semiring; use crate::{StateId, SymbolTable}; #[derive(Clone, PartialEq, Debug)] -pub struct MatcherFst { - fst_add_on: FstAddOn>, Option>)>, +pub struct MatcherFst, B, M, T> { + fst_add_on: FstAddOn>, Option>)>, matcher: PhantomData, w: PhantomData<(W, B)>, } -impl MatcherFst { +impl, B, M, T> MatcherFst { pub fn fst(&self) -> &F { self.fst_add_on.fst() } @@ -42,6 +42,23 @@ impl MatcherFst { } } +impl MatcherFst +where + W: Semiring, + F: Fst, + B: Borrow, + M: LookaheadMatcher, +{ + pub fn new_with_fst_add_on(fst_add_on: FstAddOn>, Option>)>) -> Result { + Ok(Self { + fst_add_on, + matcher: PhantomData, + w: PhantomData + }) + } + +} + // TODO: To be generalized impl MatcherFst where @@ -58,8 +75,12 @@ where LabelLookAheadRelabeler::init(&mut fst, &mut add_on)?; let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new)); - - let fst_add_on = FstAddOn::new(fst, add_on); + let fst_type = if add_on.0.is_some() { + "ilabel_lookahead".to_string() + } else { + "olabel_lookahead".to_string() + }; + let fst_add_on = FstAddOn::new(fst, add_on, fst_type); Ok(Self { fst_add_on, matcher: PhantomData, @@ -81,8 +102,12 @@ where LabelLookAheadRelabeler::relabel(fst2, &mut add_on, relabel_input)?; let add_on = (add_on.0.map(Arc::new), add_on.1.map(Arc::new)); - - let fst_add_on = FstAddOn::new(fst, add_on); + let fst_type = if add_on.0.is_some() { + "ilabel_lookahead".to_string() + } else { + "olabel_lookahead".to_string() + }; + let fst_add_on = FstAddOn::new(fst, add_on, fst_type); Ok(Self { fst_add_on, matcher: PhantomData, @@ -91,8 +116,8 @@ where } } -impl, B: Borrow, M, T> CoreFst for MatcherFst { - type TRS = as CoreFst>::TRS; +impl, B: Borrow, M, T> CoreFst for MatcherFst { + type TRS = as CoreFst>::TRS; fn start(&self) -> Option { self.fst_add_on.start() @@ -135,7 +160,7 @@ impl, B: Borrow, M, T> CoreFst for MatcherFst, B: Borrow, M, T> StateIterator<'a> +impl<'a, W: Semiring, F: Fst, B: Borrow, M, T> StateIterator<'a> for MatcherFst { type Iter = >::Iter; @@ -148,10 +173,10 @@ impl<'a, W, F: StateIterator<'a>, B: Borrow, M, T> StateIterator<'a> impl<'a, W, F, B, M, T> FstIterator<'a, W> for MatcherFst where W: Semiring, - F: FstIterator<'a, W>, + F: Fst, B: Borrow, { - type FstIter = F::FstIter; + type FstIter = >::FstIter; fn fst_iter(&'a self) -> Self::FstIter { self.fst_add_on.fst_iter() @@ -207,12 +232,12 @@ where impl FstIntoIterator for MatcherFst where W: Semiring, - F: FstIntoIterator, + F: Fst + FstIntoIterator, B: Borrow + Debug + PartialEq + Clone, T: Debug, { type TrsIter = F::TrsIter; - type FstIter = F::FstIter; + type FstIter = >::FstIter; fn fst_into_iter(self) -> Self::FstIter { self.fst_add_on.fst_into_iter() diff --git a/rustfst/src/fst_impls/const_fst/serializable_fst.rs b/rustfst/src/fst_impls/const_fst/serializable_fst.rs index 6fc1f5eb9..104140616 100644 --- a/rustfst/src/fst_impls/const_fst/serializable_fst.rs +++ b/rustfst/src/fst_impls/const_fst/serializable_fst.rs @@ -1,9 +1,6 @@ -use std::fs::{read, File}; -use std::io::BufWriter; -use std::path::Path; +use std::io::Write; use std::sync::Arc; -use anyhow::Context; use anyhow::Result; use itertools::Itertools; use nom::bytes::complete::take; @@ -22,34 +19,52 @@ use crate::parsers::bin_fst::utils_parsing::{ parse_bin_fst_tr, parse_final_weight, parse_start_state, }; use crate::parsers::nom_utils::NomCustomError; -use crate::parsers::parse_bin_i32; +use crate::parsers::{parse_bin_i32, SerializeBinary}; use crate::parsers::text_fst::ParsedTextFst; use crate::parsers::write_bin_i32; use crate::semirings::SerializableSemiring; use crate::{Tr, EPS_LABEL}; -impl SerializableFst for ConstFst { - fn fst_type() -> String { - "const".to_string() - } +impl SerializeBinary for ConstFst { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let stream_len = i.len(); - fn read>(path_bin_fst: P) -> Result { - let data = read(path_bin_fst.as_ref()).with_context(|| { - format!( - "Can't open ConstFst binary file : {:?}", - path_bin_fst.as_ref() - ) - })?; + let (mut i, hdr) = FstHeader::parse( + i, + CONST_MIN_FILE_VERSION, + Some(ConstFst::::fst_type()), + Tr::::tr_type(), + )?; + let aligned = hdr.version == CONST_ALIGNED_FILE_VERSION; + let pos = stream_len - i.len(); - let (_, parsed_fst) = parse_const_fst(&data) - .map_err(|_| format_err!("Error while parsing binary ConstFst"))?; + // Align input + if aligned && hdr.num_states > 0 && pos % CONST_ARCH_ALIGNMENT > 0 { + i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0; + } + let (mut i, const_states) = count(parse_const_state, hdr.num_states as usize)(i)?; + let pos = stream_len - i.len(); - Ok(parsed_fst) + // Align input + if aligned && hdr.num_trs > 0 && pos % CONST_ARCH_ALIGNMENT > 0 { + i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0; + } + let (i, const_trs) = count(parse_bin_fst_tr, hdr.num_trs as usize)(i)?; + + Ok(( + i, + ConstFst { + start: parse_start_state(hdr.start), + states: const_states, + trs: Arc::new(const_trs), + isymt: hdr.isymt, + osymt: hdr.osymt, + properties: FstProperties::from_bits_truncate(hdr.properties), + }, + )) } - fn write>(&self, path_bin_fst: P) -> Result<()> { - let mut file = BufWriter::new(File::create(path_bin_fst)?); - + fn write_binary(&self, writer: &mut WB) -> Result<()> { let mut flags = FstFlags::empty(); if self.input_symbols().is_some() { flags |= FstFlags::HAS_ISYMBOLS; @@ -72,28 +87,34 @@ impl SerializableFst for ConstFst { isymt: self.input_symbols().cloned(), osymt: self.output_symbols().cloned(), }; - hdr.write(&mut file)?; + hdr.write(writer)?; let zero = W::zero(); for const_state in &self.states { let f_weight = const_state.final_weight.as_ref().unwrap_or_else(|| &zero); - f_weight.write_binary(&mut file)?; + f_weight.write_binary(writer)?; - write_bin_i32(&mut file, const_state.pos as i32)?; - write_bin_i32(&mut file, const_state.ntrs as i32)?; - write_bin_i32(&mut file, const_state.niepsilons as i32)?; - write_bin_i32(&mut file, const_state.noepsilons as i32)?; + write_bin_i32(writer, const_state.pos as i32)?; + write_bin_i32(writer, const_state.ntrs as i32)?; + write_bin_i32(writer, const_state.niepsilons as i32)?; + write_bin_i32(writer, const_state.noepsilons as i32)?; } for tr in &*self.trs { - write_bin_i32(&mut file, tr.ilabel as i32)?; - write_bin_i32(&mut file, tr.olabel as i32)?; - tr.weight.write_binary(&mut file)?; - write_bin_i32(&mut file, tr.nextstate as i32)?; + write_bin_i32(writer, tr.ilabel as i32)?; + write_bin_i32(writer, tr.olabel as i32)?; + tr.weight.write_binary(writer)?; + write_bin_i32(writer, tr.nextstate as i32)?; } Ok(()) } +} + +impl SerializableFst for ConstFst { + fn fst_type() -> String { + "const".to_string() + } fn from_parsed_fst_text(mut parsed_fst_text: ParsedTextFst) -> Result { let start_state = parsed_fst_text.start(); @@ -207,43 +228,3 @@ fn parse_const_state( }, )) } - -fn parse_const_fst( - i: &[u8], -) -> IResult<&[u8], ConstFst, NomCustomError<&[u8]>> { - let stream_len = i.len(); - - let (mut i, hdr) = FstHeader::parse( - i, - CONST_MIN_FILE_VERSION, - ConstFst::::fst_type(), - Tr::::tr_type(), - )?; - let aligned = hdr.version == CONST_ALIGNED_FILE_VERSION; - let pos = stream_len - i.len(); - - // Align input - if aligned && hdr.num_states > 0 && pos % CONST_ARCH_ALIGNMENT > 0 { - i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0; - } - let (mut i, const_states) = count(parse_const_state, hdr.num_states as usize)(i)?; - let pos = stream_len - i.len(); - - // Align input - if aligned && hdr.num_trs > 0 && pos % CONST_ARCH_ALIGNMENT > 0 { - i = take(CONST_ARCH_ALIGNMENT - (pos % CONST_ARCH_ALIGNMENT))(i)?.0; - } - let (i, const_trs) = count(parse_bin_fst_tr, hdr.num_trs as usize)(i)?; - - Ok(( - i, - ConstFst { - start: parse_start_state(hdr.start), - states: const_states, - trs: Arc::new(const_trs), - isymt: hdr.isymt, - osymt: hdr.osymt, - properties: FstProperties::from_bits_truncate(hdr.properties), - }, - )) -} diff --git a/rustfst/src/fst_impls/vector_fst/parse_const.rs b/rustfst/src/fst_impls/vector_fst/parse_const.rs index e5ca2e917..4e158cf5d 100644 --- a/rustfst/src/fst_impls/vector_fst/parse_const.rs +++ b/rustfst/src/fst_impls/vector_fst/parse_const.rs @@ -76,7 +76,7 @@ fn parse_const_fst( i, CONST_MIN_FILE_VERSION, // Intentional as the ConstFst file is being parsed. - ConstFst::::fst_type(), + Some(ConstFst::::fst_type()), Tr::::tr_type(), )?; let aligned = hdr.version == CONST_ALIGNED_FILE_VERSION; diff --git a/rustfst/src/fst_impls/vector_fst/serializable_fst.rs b/rustfst/src/fst_impls/vector_fst/serializable_fst.rs index f2d532875..110fa5a5c 100644 --- a/rustfst/src/fst_impls/vector_fst/serializable_fst.rs +++ b/rustfst/src/fst_impls/vector_fst/serializable_fst.rs @@ -1,9 +1,6 @@ -use std::fs::{read, File}; -use std::io::BufWriter; -use std::path::Path; +use std::io::Write; use std::sync::Arc; -use anyhow::Context; use anyhow::Result; use nom::multi::count; use nom::number::complete::le_i64; @@ -23,29 +20,30 @@ use crate::parsers::text_fst::ParsedTextFst; use crate::parsers::write_bin_i64; use crate::semirings::SerializableSemiring; use crate::{StateId, Tr, Trs, TrsVec, EPS_LABEL}; - -impl SerializableFst for VectorFst { - fn fst_type() -> String { - "vector".to_string() - } - - fn read>(path_bin_fst: P) -> Result { - let data = read(path_bin_fst.as_ref()).with_context(|| { - format!( - "Can't open VectorFst binary file : {:?}", - path_bin_fst.as_ref() - ) - })?; - - let (_, parsed_fst) = parse_vector_fst(&data) - .map_err(|e| format_err!("Error while parsing binary VectorFst : {:?}", e))?; - - Ok(parsed_fst) +use crate::prelude::SerializeBinary; + +impl SerializeBinary for VectorFst { + fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> { + let (i, header) = FstHeader::parse( + i, + VECTOR_MIN_FILE_VERSION, + Some(VectorFst::::fst_type()), + Tr::::tr_type(), + )?; + let (i, states) = count(parse_vector_fst_state, header.num_states as usize)(i)?; + Ok(( + i, + VectorFst { + start_state: parse_start_state(header.start), + states, + isymt: header.isymt, + osymt: header.osymt, + properties: FstProperties::from_bits_truncate(header.properties), + }, + )) } - fn write>(&self, path_bin_fst: P) -> Result<()> { - let mut file = BufWriter::new(File::create(path_bin_fst)?); - + fn write_binary(&self, writer: &mut WB) -> Result<()> { let num_trs: usize = (0..self.num_states()) .map(|s: usize| unsafe { self.num_trs_unchecked(s as StateId) }) .sum(); @@ -73,22 +71,28 @@ impl SerializableFst for VectorFst { isymt: self.input_symbols().cloned(), osymt: self.output_symbols().cloned(), }; - hdr.write(&mut file)?; + hdr.write(writer)?; // FstBody for state in 0..self.num_states() { let state = state as StateId; let f_weight = unsafe { self.final_weight_unchecked(state).unwrap_or_else(W::zero) }; - f_weight.write_binary(&mut file)?; - write_bin_i64(&mut file, unsafe { self.num_trs_unchecked(state) } as i64)?; + f_weight.write_binary(writer)?; + write_bin_i64(writer, unsafe { self.num_trs_unchecked(state) } as i64)?; for tr in unsafe { self.get_trs_unchecked(state).trs() } { - write_bin_fst_tr(&mut file, tr)?; + write_bin_fst_tr(writer, tr)?; } } Ok(()) } +} + +impl SerializableFst for VectorFst { + fn fst_type() -> String { + "vector".to_string() + } fn from_parsed_fst_text(parsed_fst_text: ParsedTextFst) -> Result { let start_state = parsed_fst_text.start(); @@ -155,25 +159,3 @@ fn parse_vector_fst_state( }, )) } - -fn parse_vector_fst( - i: &[u8], -) -> IResult<&[u8], VectorFst, NomCustomError<&[u8]>> { - let (i, header) = FstHeader::parse( - i, - VECTOR_MIN_FILE_VERSION, - VectorFst::::fst_type(), - Tr::::tr_type(), - )?; - let (i, states) = count(parse_vector_fst_state, header.num_states as usize)(i)?; - Ok(( - i, - VectorFst { - start_state: parse_start_state(header.start), - states, - isymt: header.isymt, - osymt: header.osymt, - properties: FstProperties::from_bits_truncate(header.properties), - }, - )) -} diff --git a/rustfst/src/fst_traits/serializable_fst.rs b/rustfst/src/fst_traits/serializable_fst.rs index f7fb43173..edf0caa12 100644 --- a/rustfst/src/fst_traits/serializable_fst.rs +++ b/rustfst/src/fst_traits/serializable_fst.rs @@ -2,7 +2,7 @@ use std::fs::File; use std::io::{BufWriter, LineWriter, Write}; use std::path::Path; -use anyhow::Result; +use anyhow::{Context, Result}; use unsafe_unwrap::UnsafeUnwrap; use crate::fst_traits::ExpandedFst; @@ -10,19 +10,36 @@ use crate::parsers::text_fst::ParsedTextFst; use crate::semirings::SerializableSemiring; use crate::Trs; use crate::{DrawingConfig, StateId}; +use crate::parsers::SerializeBinary; -/// Trait definining the methods an Fst must implement to be serialized and deserialized. -pub trait SerializableFst: ExpandedFst { - /// String identifying the type of the FST. Will be used when serialiing and +/// Trait defining the methods an Fst must implement to be serialized and deserialized. +pub trait SerializableFst: ExpandedFst + SerializeBinary { + /// String identifying the type of the FST. Will be used when serializing and /// deserializing an FST in binary format. fn fst_type() -> String; // BINARY /// Loads an FST from a file in binary format. - fn read>(path_bin_fst: P) -> Result; + fn read>(path_bin_fst: P) -> Result { + let data = std::fs::read(path_bin_fst.as_ref()).with_context(|| { + format!( + "Can't open {} fst binary file : {:?}", + Self::fst_type(), + path_bin_fst.as_ref() + ) + })?; + + let (_, parsed_fst) = Self::parse_binary(&data) + .map_err(|_| format_err!("Error while parsing binary ConstFst"))?; + + Ok(parsed_fst) + } /// Writes the FST to a file in binary format. - fn write>(&self, path_bin_fst: P) -> Result<()>; + fn write>(&self, path_bin_fst: P) -> Result<()> { + let mut file = BufWriter::new(File::create(path_bin_fst)?); + self.write_binary(&mut file) + } // TEXT diff --git a/rustfst/src/parsers/bin_fst/fst_header.rs b/rustfst/src/parsers/bin_fst/fst_header.rs index 7824cc646..f77239c1d 100644 --- a/rustfst/src/parsers/bin_fst/fst_header.rs +++ b/rustfst/src/parsers/bin_fst/fst_header.rs @@ -70,12 +70,12 @@ impl FstHeader { pub(crate) fn parse, S2: AsRef>( i: &[u8], min_file_version: i32, - fst_loading_type: S1, + fst_loading_type: Option, // Do not verify if None tr_loading_type: S2, ) -> IResult<&[u8], FstHeader, NomCustomError<&[u8]>> { let (i, magic_number) = verify(parse_bin_i32, |v: &i32| *v == FST_MAGIC_NUMBER)(i)?; let (i, fst_type) = verify(OpenFstString::parse, |v| { - v.s.as_str() == fst_loading_type.as_ref() + fst_loading_type.is_none() || v.s.as_str() == fst_loading_type.as_ref().unwrap().as_ref() })(i)?; let (i, tr_type) = verify(OpenFstString::parse, |v| { v.s.as_str() == tr_loading_type.as_ref() @@ -159,6 +159,10 @@ impl OpenFstString { write_bin_i32(file, self.n)?; file.write_all(self.s.as_bytes()).map_err(|e| e.into()) } + + pub(crate) fn s(&self) -> &String { + &self.s + } } impl Into for OpenFstString { diff --git a/rustfst/src/parsers/utils_parsing.rs b/rustfst/src/parsers/utils_parsing.rs index 979368878..ff90b8340 100644 --- a/rustfst/src/parsers/utils_parsing.rs +++ b/rustfst/src/parsers/utils_parsing.rs @@ -32,3 +32,8 @@ pub fn parse_bin_f32(i: &[u8]) -> IResult<&[u8], f32, NomCustomError<&[u8]>> { pub fn parse_bin_u8(i: &[u8]) -> IResult<&[u8], u8, NomCustomError<&[u8]>> { le_u8(i) } + +#[inline] +pub fn parse_bin_bool(i: &[u8]) -> IResult<&[u8], bool, NomCustomError<&[u8]>> { + le_u8(i).map(|(s, v)| (s, v != 0)) +} \ No newline at end of file diff --git a/rustfst/src/parsers/utils_serialization.rs b/rustfst/src/parsers/utils_serialization.rs index 4462c60f4..177e63193 100644 --- a/rustfst/src/parsers/utils_serialization.rs +++ b/rustfst/src/parsers/utils_serialization.rs @@ -31,3 +31,8 @@ pub fn write_bin_f32(file: &mut F, i: f32) -> Result<()> { pub(crate) fn write_bin_u8(file: &mut F, i: u8) -> Result<()> { file.write_all(&i.to_le_bytes()).map_err(|e| e.into()) } + +#[inline] +pub(crate) fn write_bin_bool(file: &mut F, i: bool) -> Result<()> { + file.write_all(&((if i { 1u8 } else { 0u8 }).to_le_bytes())).map_err(|e| e.into()) +}