Skip to content
Open
125 changes: 109 additions & 16 deletions rustfst/src/algorithms/compose/add_on.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,38 @@
use std::fmt::Debug;
use std::io::Write;
use std::marker::PhantomData;
use std::sync::Arc;

use anyhow::Result;
use nom::combinator::verify;
use nom::IResult;

use crate::{NomCustomError, StateId, SymbolTable, Tr};
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, StateIterator};
use crate::fst_properties::properties::EXPANDED;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, FstIntoIterator, FstIterator, SerializableFst, StateIterator};
use crate::parsers::{parse_bin_bool, parse_bin_i32, write_bin_bool, write_bin_i32};
use crate::parsers::bin_fst::fst_header::{FST_MAGIC_NUMBER, FstFlags, FstHeader, OpenFstString};
use crate::prelude::{SerializableSemiring, SerializeBinary};
use crate::semirings::Semiring;
use crate::{StateId, SymbolTable};

/// Adds an object of type T to an FST.
/// The resulting type is a new FST implementation.
#[derive(Debug, PartialEq, Clone)]
pub struct FstAddOn<F, T> {
pub struct FstAddOn<W, F, T>
where
W: Semiring,
F: Fst<W>
{
pub(crate) fst: F,
pub(crate) add_on: T,
w: PhantomData<W>,
fst_type: String
}

impl<F, T> FstAddOn<F, T> {
pub fn new(fst: F, add_on: T) -> Self {
Self { fst, add_on }
impl<W: Semiring, F: Fst<W>, T> FstAddOn<W, F, T> {
pub fn new(fst: F, add_on: T, fst_type: String) -> Self {
Self { fst, add_on, w: PhantomData, fst_type }
}

pub fn fst(&self) -> &F {
Expand All @@ -34,7 +48,7 @@ impl<F, T> FstAddOn<F, T> {
}
}

impl<W: Semiring, F: CoreFst<W>, T> CoreFst<W> for FstAddOn<F, T> {
impl<W: Semiring, F: Fst<W>, T> CoreFst<W> for FstAddOn<W, F, T> {
type TRS = F::TRS;

fn start(&self) -> Option<StateId> {
Expand Down Expand Up @@ -78,27 +92,27 @@ impl<W: Semiring, F: CoreFst<W>, T> CoreFst<W> for FstAddOn<F, T> {
}
}

impl<'a, F: StateIterator<'a>, T> StateIterator<'a> for FstAddOn<F, T> {
impl<'a, W: Semiring, F: Fst<W>, T> StateIterator<'a> for FstAddOn<W, F, T> {
type Iter = <F as StateIterator<'a>>::Iter;

fn states_iter(&'a self) -> Self::Iter {
self.fst.states_iter()
}
}

impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn<F, T>
impl<'a, W, F, T> FstIterator<'a, W> for FstAddOn<W, F, T>
where
W: Semiring + 'a,
F: FstIterator<'a, W>,
F: Fst<W>,
{
type FstIter = F::FstIter;
type FstIter = <F as FstIterator<'a, W>>::FstIter;

fn fst_iter(&'a self) -> Self::FstIter {
self.fst.fst_iter()
}
}

impl<W, F, T: Debug> Fst<W> for FstAddOn<F, T>
impl<W, F, T: Debug> Fst<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: Fst<W>,
Expand Down Expand Up @@ -128,7 +142,7 @@ where
}
}

impl<W, F, T> ExpandedFst<W> for FstAddOn<F, T>
impl<W, F, T> ExpandedFst<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: ExpandedFst<W>,
Expand All @@ -139,16 +153,95 @@ where
}
}

impl<W, F, T> FstIntoIterator<W> for FstAddOn<F, T>
impl<W, F, T> FstIntoIterator<W> for FstAddOn<W, F, T>
where
W: Semiring,
F: FstIntoIterator<W>,
F: FstIntoIterator<W> + Fst<W> ,
T: Debug,
{
type TrsIter = F::TrsIter;
type FstIter = F::FstIter;
type FstIter = <F as FstIntoIterator<W>>::FstIter;

fn fst_into_iter(self) -> Self::FstIter {
self.fst.fst_into_iter()
}
}

static ADD_ON_MAGIC_NUMBER: i32 = 446681434;
static ADD_ON_MIN_FILE_VERSION: i32 = 1;
static ADD_ON_FILE_VERSION: i32 = 1;

impl<W, F, AO1, AO2> SerializeBinary for FstAddOn<W, F, (Option<Arc<AO1>>, Option<Arc<AO2>>)>
where
W: SerializableSemiring,
F: SerializableFst<W>,
AO1: SerializeBinary + Debug + Clone + PartialEq,
AO2: SerializeBinary + Debug + Clone + PartialEq,
{
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {

let (i, hdr) = FstHeader::parse(
i,
ADD_ON_MIN_FILE_VERSION,
Option::<&str>::None,
Tr::<W>::tr_type(),
)?;

let (i, _) = verify(parse_bin_i32, |v: &i32| *v == ADD_ON_MAGIC_NUMBER)(i)?;
let (i, fst) = F::parse_binary(i)?;

let (i, _have_addon) = verify(parse_bin_bool, |v| *v)(i)?;

let (i, have_addon1) = parse_bin_bool(i)?;
let (i, add_on_1) = if have_addon1 {
let (s, a) = AO1::parse_binary(i)?;
(s, Some(a))
} else {
(i, None)
};
let (i, have_addon2) = parse_bin_bool(i)?;
let (i, add_on_2) = if have_addon2 {
let (s, a) = AO2::parse_binary(i)?;
(s, Some(a))
} else {
(i, None)
};

let add_on = (add_on_1.map(Arc::new), add_on_2.map(Arc::new));
let fst_add_on = FstAddOn::new(fst, add_on, hdr.fst_type.s().clone());
Ok((i, fst_add_on))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
let hdr = FstHeader {
magic_number: FST_MAGIC_NUMBER,
fst_type: OpenFstString::new(&self.fst_type),
tr_type: OpenFstString::new(Tr::<W>::tr_type()),
version: ADD_ON_FILE_VERSION,
flags: FstFlags::empty(),
properties: self.properties().bits() | EXPANDED,
start: -1,
num_states: 0,
num_trs: 0,
isymt: None,
osymt: None,
};
hdr.write(writer)?;
write_bin_i32(writer, ADD_ON_MAGIC_NUMBER)?;
self.fst.write_binary(writer)?;
write_bin_bool(writer, true)?;
if let Some(add_on) = self.add_on.0.as_ref() {
write_bin_bool(writer, true)?;
add_on.write_binary(writer)?;
} else {
write_bin_bool(writer, false)?;
}
if let Some(add_on) = self.add_on.1.as_ref() {
write_bin_bool(writer, true)?;
add_on.write_binary(writer)?;
} else {
write_bin_bool(writer, false)?;
}
Ok(())
}
}
69 changes: 68 additions & 1 deletion rustfst/src/algorithms/compose/interval_set.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashSet;
use std::io::Write;
use std::slice::Iter as IterSlice;
use std::vec::IntoIter as IntoIterVec;
use nom::IResult;
use nom::multi::count;
use superslice::Ext;
use unsafe_unwrap::UnsafeUnwrap;
use crate::NomCustomError;
use crate::parsers::{parse_bin_i32, parse_bin_i64, write_bin_i32, write_bin_i64};
use crate::prelude::SerializeBinary;

/// Half-open integral interval [a, b) of signed integers of type T.
#[derive(PartialEq, Clone, Eq, Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -47,6 +53,26 @@ impl Ord for IntInterval {
}
}

impl SerializeBinary for IntInterval {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, begin) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?;
let (i, end) = parse_bin_i32(i).map(|(s, v)| (s, v as usize))?;
Ok((
i,
IntInterval {
begin,
end
},
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
write_bin_i32(writer, self.begin as i32)?;
write_bin_i32(writer, self.end as i32)?;
Ok(())
}
}

/// Stores IntIntervals in a vector. In addition, keeps the count of points in
/// all intervals.
#[derive(Clone, PartialOrd, PartialEq, Debug)]
Expand Down Expand Up @@ -95,11 +121,52 @@ impl VectorIntervalStore {
}
}

impl SerializeBinary for VectorIntervalStore {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, interval_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?;
let (i, intervals) = count(IntInterval::parse_binary, interval_count)(i)?;
let (i, store_count) = parse_bin_i32(i)?;
Ok((
i,
VectorIntervalStore {
intervals,
count: Some(store_count as usize)
},
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
write_bin_i64(writer, self.intervals.len() as i64)?;
for interval in self.intervals.iter() {
interval.write_binary(writer)?;
}
write_bin_i32(writer, self.count.unwrap_or_default() as i32)?;
Ok(())
}
}

#[derive(PartialOrd, PartialEq, Default, Clone, Debug)]
pub struct IntervalSet {
pub(crate) intervals: VectorIntervalStore,
}

impl SerializeBinary for IntervalSet {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, intervals) = VectorIntervalStore::parse_binary(i)?;
Ok((
i,
IntervalSet {
intervals
},
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> anyhow::Result<()> {
self.intervals.write_binary(writer)?;
Ok(())
}
}

impl IntervalSet {
pub fn len(&self) -> usize {
self.intervals.len()
Expand Down Expand Up @@ -149,7 +216,7 @@ impl IntervalSet {
elt.begin + 1 == elt.end
}

// Sorts, collapses overlapping and adjacent interals, and sets count.
// Sorts, collapses overlapping and adjacent intervals, and sets count.
pub fn normalize(&mut self) {
let intervals = &mut self.intervals.intervals;
intervals.sort();
Expand Down
73 changes: 72 additions & 1 deletion rustfst/src/algorithms/compose/label_reachable.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::io::Write;
use std::sync::Arc;

use anyhow::Result;
use nom::IResult;
use nom::multi::count;

use crate::algorithms::compose::{IntervalSet, StateReachable};
use crate::algorithms::tr_compares::{ILabelCompare, OLabelCompare};
Expand All @@ -11,7 +14,8 @@ use crate::fst_impls::VectorFst;
use crate::fst_properties::FstProperties;
use crate::fst_traits::{CoreFst, ExpandedFst, Fst, MutableFst};
use crate::semirings::Semiring;
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED};
use crate::{Label, StateId, Tr, Trs, EPS_LABEL, NO_LABEL, UNASSIGNED, NomCustomError};
use crate::parsers::{parse_bin_bool, parse_bin_i64, parse_bin_i32, parse_bin_u32, SerializeBinary, write_bin_bool, write_bin_i64, write_bin_u32, write_bin_i32};

#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachableData {
Expand Down Expand Up @@ -116,6 +120,73 @@ impl LabelReachableData {
}
}

fn parse_label_map(i: &[u8]) -> IResult<&[u8], HashMap<Label, Label>, NomCustomError<&[u8]>> {
let mut stream = i;
let r = parse_bin_i64(stream).map(|(s, v)| (s, v as usize))?;
stream = r.0;
let map_size = r.1;
let mut map = HashMap::with_capacity(map_size);
for _ in 0..map_size {
let r = parse_bin_i32(stream).map(|(s, v)| (s, v as Label))?;
let key = r.1;
let r = parse_bin_i32(r.0).map(|(s, v)| (s, v as Label))?;
stream = r.0;
let val = r.1;
map.insert(key, val);
}
Ok((stream, map))
}

fn write_label_map<WB: Write>(writer: &mut WB, map: &HashMap<Label, Label>) -> Result<()> {
write_bin_i64(writer, map.len() as i64)?;
for (k, v) in map.iter() {
write_bin_i32(writer, *k as i32)?;
write_bin_i32(writer, *v as i32)?;
}
Ok(())
}

impl SerializeBinary for LabelReachableData {
fn parse_binary(i: &[u8]) -> IResult<&[u8], Self, NomCustomError<&[u8]>> {
let (i, reach_input) = parse_bin_bool(i)?;
let (i, have_relabel_data) = parse_bin_bool(i)?;
let (i, label2index) = if have_relabel_data {
parse_label_map(i)?
} else {
(i, Default::default())
};
let (i, final_label) = parse_bin_u32(i).map(|(s, v)| (s, v as Label))?;
let (i, set_count) = parse_bin_i64(i).map(|(s, v)| (s, v as usize))?;
let (i, interval_sets) = count(IntervalSet::parse_binary, set_count)(i)?;
Ok((
i,
LabelReachableData {
reach_input,
final_label,
label2index,
interval_sets
}
))
}

fn write_binary<WB: Write>(&self, writer: &mut WB) -> Result<()> {
write_bin_bool(writer, self.reach_input)?;
// OpenFst checks keep_relabel_data here which is missing in this struct.
// Instead we check if we have any data in label2index;
let have_relabel_data = !self.label2index.is_empty();
write_bin_bool(writer, have_relabel_data)?;
if have_relabel_data {
write_label_map(writer, &self.label2index)?;
}
write_bin_u32(writer, self.final_label as u32)?;
write_bin_i64(writer, self.interval_sets.len() as i64)?;
for interval_set in self.interval_sets.iter() {
interval_set.write_binary(writer)?;
}
Ok(())
}
}

#[derive(Debug, Clone, PartialEq)]
pub struct LabelReachable {
data: Arc<LabelReachableData>,
Expand Down
Loading