|
| 1 | +// The struct `FieldInnerProductCols` is modified from succinctlabs/sp1 under MIT license |
| 2 | + |
| 3 | +// The MIT License (MIT) |
| 4 | + |
| 5 | +// Copyright (c) 2023 Succinct Labs |
| 6 | + |
| 7 | +// Permission is hereby granted, free of charge, to any person obtaining a copy |
| 8 | +// of this software and associated documentation files (the "Software"), to deal |
| 9 | +// in the Software without restriction, including without limitation the rights |
| 10 | +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 11 | +// copies of the Software, and to permit persons to whom the Software is |
| 12 | +// furnished to do so, subject to the following conditions: |
| 13 | + |
| 14 | +// The above copyright notice and this permission notice shall be included in |
| 15 | +// all copies or substantial portions of the Software. |
| 16 | + |
| 17 | +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 18 | +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 19 | +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 20 | +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 21 | +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 22 | +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
| 23 | +// THE SOFTWARE. |
| 24 | + |
| 25 | +use derive::AlignedBorrow; |
| 26 | +use ff_ext::{ExtensionField, SmallField}; |
| 27 | +use generic_array::{GenericArray, sequence::GenericSequence}; |
| 28 | +use gkr_iop::{circuit_builder::CircuitBuilder, error::CircuitBuilderError}; |
| 29 | +use multilinear_extensions::{Expression, ToExpr, WitIn}; |
| 30 | +use num::{BigUint, Zero}; |
| 31 | +use sp1_curves::{ |
| 32 | + params::{FieldParameters, Limbs}, |
| 33 | + polynomial::Polynomial, |
| 34 | +}; |
| 35 | +use std::fmt::Debug; |
| 36 | + |
| 37 | +use crate::{ |
| 38 | + gadgets::{ |
| 39 | + util::{compute_root_quotient_and_shift, split_u16_limbs_to_u8_limbs}, |
| 40 | + util_expr::eval_field_operation, |
| 41 | + }, |
| 42 | + witness::LkMultiplicity, |
| 43 | +}; |
| 44 | + |
| 45 | +/// A set of columns to compute `InnerProduct([a], [b])` where a, b are emulated elements. |
| 46 | +/// |
| 47 | +/// *Safety*: The `FieldInnerProductCols` asserts that `result = sum_i a_i * b_i mod M` where |
| 48 | +/// `M` is the modulus `P::modulus()` under the assumption that the length of `a` and `b` is small |
| 49 | +/// enough so that the vanishing polynomial has limbs bounded by the witness shift. It is the |
| 50 | +/// responsibility of the caller to ensure that the length of `a` and `b` is small enough. |
| 51 | +#[derive(Debug, Clone, AlignedBorrow)] |
| 52 | +#[repr(C)] |
| 53 | +pub struct FieldInnerProductCols<T, P: FieldParameters> { |
| 54 | + /// The result of `a inner product b`, where a, b are field elements |
| 55 | + pub result: Limbs<T, P::Limbs>, |
| 56 | + pub(crate) carry: Limbs<T, P::Limbs>, |
| 57 | + pub(crate) witness_low: Limbs<T, P::Witness>, |
| 58 | + pub(crate) witness_high: Limbs<T, P::Witness>, |
| 59 | +} |
| 60 | + |
| 61 | +impl<P: FieldParameters> FieldInnerProductCols<WitIn, P> { |
| 62 | + pub fn create<E: ExtensionField, NR, N>(cb: &mut CircuitBuilder<E>, name_fn: N) -> Self |
| 63 | + where |
| 64 | + NR: Into<String>, |
| 65 | + N: FnOnce() -> NR, |
| 66 | + { |
| 67 | + let name: String = name_fn().into(); |
| 68 | + Self { |
| 69 | + result: Limbs(GenericArray::generate(|_| { |
| 70 | + cb.create_witin(|| format!("{}_result", name)) |
| 71 | + })), |
| 72 | + carry: Limbs(GenericArray::generate(|_| { |
| 73 | + cb.create_witin(|| format!("{}_carry", name)) |
| 74 | + })), |
| 75 | + witness_low: Limbs(GenericArray::generate(|_| { |
| 76 | + cb.create_witin(|| format!("{}_witness_low", name)) |
| 77 | + })), |
| 78 | + witness_high: Limbs(GenericArray::generate(|_| { |
| 79 | + cb.create_witin(|| format!("{}_witness_high", name)) |
| 80 | + })), |
| 81 | + } |
| 82 | + } |
| 83 | +} |
| 84 | + |
| 85 | +impl<F: SmallField, P: FieldParameters> FieldInnerProductCols<F, P> { |
| 86 | + pub fn populate( |
| 87 | + &mut self, |
| 88 | + record: &mut LkMultiplicity, |
| 89 | + a: &[BigUint], |
| 90 | + b: &[BigUint], |
| 91 | + ) -> BigUint { |
| 92 | + let p_a_vec: Vec<Polynomial<F>> = a |
| 93 | + .iter() |
| 94 | + .map(|x| P::to_limbs_field::<F, _>(x).into()) |
| 95 | + .collect(); |
| 96 | + let p_b_vec: Vec<Polynomial<F>> = b |
| 97 | + .iter() |
| 98 | + .map(|x| P::to_limbs_field::<F, _>(x).into()) |
| 99 | + .collect(); |
| 100 | + |
| 101 | + let modulus = &P::modulus(); |
| 102 | + let inner_product = a |
| 103 | + .iter() |
| 104 | + .zip(b.iter()) |
| 105 | + .fold(BigUint::zero(), |acc, (c, d)| acc + c * d); |
| 106 | + |
| 107 | + let result = &(&inner_product % modulus); |
| 108 | + let carry = &((&inner_product - result) / modulus); |
| 109 | + assert!(result < modulus); |
| 110 | + assert!(carry < &(2u32 * modulus)); |
| 111 | + assert_eq!(carry * modulus, inner_product - result); |
| 112 | + |
| 113 | + let p_modulus: Polynomial<F> = P::to_limbs_field::<F, _>(modulus).into(); |
| 114 | + let p_result: Polynomial<F> = P::to_limbs_field::<F, _>(result).into(); |
| 115 | + let p_carry: Polynomial<F> = P::to_limbs_field::<F, _>(carry).into(); |
| 116 | + |
| 117 | + // Compute the vanishing polynomial. |
| 118 | + let p_inner_product = p_a_vec |
| 119 | + .into_iter() |
| 120 | + .zip(p_b_vec) |
| 121 | + .fold(Polynomial::<F>::new(vec![F::ZERO]), |acc, (c, d)| { |
| 122 | + acc + &c * &d |
| 123 | + }); |
| 124 | + let p_vanishing = p_inner_product - &p_result - &p_carry * &p_modulus; |
| 125 | + assert_eq!(p_vanishing.degree(), P::NB_WITNESS_LIMBS); |
| 126 | + |
| 127 | + let p_witness = compute_root_quotient_and_shift( |
| 128 | + &p_vanishing, |
| 129 | + P::WITNESS_OFFSET, |
| 130 | + P::NB_BITS_PER_LIMB as u32, |
| 131 | + P::NB_WITNESS_LIMBS, |
| 132 | + ); |
| 133 | + let (p_witness_low, p_witness_high) = split_u16_limbs_to_u8_limbs(&p_witness); |
| 134 | + |
| 135 | + self.result = p_result.into(); |
| 136 | + self.carry = p_carry.into(); |
| 137 | + self.witness_low = Limbs(p_witness_low.try_into().unwrap()); |
| 138 | + self.witness_high = Limbs(p_witness_high.try_into().unwrap()); |
| 139 | + |
| 140 | + // Range checks |
| 141 | + record.assert_byte_fields(&self.result.0); |
| 142 | + record.assert_byte_fields(&self.carry.0); |
| 143 | + record.assert_byte_fields(&self.witness_low.0); |
| 144 | + record.assert_byte_fields(&self.witness_high.0); |
| 145 | + |
| 146 | + result.clone() |
| 147 | + } |
| 148 | +} |
| 149 | + |
| 150 | +impl<Expr: Clone, P: FieldParameters> FieldInnerProductCols<Expr, P> { |
| 151 | + pub fn eval<E>( |
| 152 | + &self, |
| 153 | + builder: &mut CircuitBuilder<E>, |
| 154 | + a: &[impl Into<Polynomial<Expression<E>>> + Clone], |
| 155 | + b: &[impl Into<Polynomial<Expression<E>>> + Clone], |
| 156 | + ) -> Result<(), CircuitBuilderError> |
| 157 | + where |
| 158 | + E: ExtensionField, |
| 159 | + Expr: ToExpr<E, Output = Expression<E>>, |
| 160 | + Expression<E>: From<Expr>, |
| 161 | + { |
| 162 | + let p_a_vec: Vec<Polynomial<Expression<E>>> = a.iter().cloned().map(|x| x.into()).collect(); |
| 163 | + let p_b_vec: Vec<Polynomial<Expression<E>>> = b.iter().cloned().map(|x| x.into()).collect(); |
| 164 | + let p_result: Polynomial<Expression<E>> = self.result.clone().into(); |
| 165 | + let p_carry: Polynomial<Expression<E>> = self.carry.clone().into(); |
| 166 | + |
| 167 | + let p_zero = Polynomial::<Expression<E>>::new(vec![Expression::<E>::ZERO]); |
| 168 | + |
| 169 | + let p_inner_product = p_a_vec |
| 170 | + .iter() |
| 171 | + .zip(p_b_vec.iter()) |
| 172 | + .map(|(p_a, p_b)| p_a * p_b) |
| 173 | + .collect::<Vec<_>>() |
| 174 | + .iter() |
| 175 | + .fold(p_zero, |acc, x| acc + x); |
| 176 | + |
| 177 | + let p_inner_product_minus_result = &p_inner_product - &p_result; |
| 178 | + let p_limbs = |
| 179 | + Polynomial::from_iter(P::modulus_field_iter::<E::BaseField>().map(|x| x.expr())); |
| 180 | + let p_vanishing = &p_inner_product_minus_result - &(&p_carry * &p_limbs); |
| 181 | + |
| 182 | + let p_witness_low = self.witness_low.0.iter().into(); |
| 183 | + let p_witness_high = self.witness_high.0.iter().into(); |
| 184 | + |
| 185 | + eval_field_operation::<E, P>(builder, &p_vanishing, &p_witness_low, &p_witness_high)?; |
| 186 | + |
| 187 | + // Range checks for the result, carry, and witness columns. |
| 188 | + builder.assert_bytes(|| "field_inner_product result", &self.result.0)?; |
| 189 | + builder.assert_bytes(|| "field_inner_product carry", &self.carry.0)?; |
| 190 | + builder.assert_bytes(|| "field_inner_product witness_low", &self.witness_low.0)?; |
| 191 | + builder.assert_bytes(|| "field_inner_product witness_high", &self.witness_high.0) |
| 192 | + } |
| 193 | +} |
0 commit comments