diff --git a/README.md b/README.md index 3cc008a2f..a64e7cf13 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ This library is released under the MIT License and the Apache v2 License (see [L This repository contains two Rust crates: * [`ark-snark`](snark): Provides generic traits for zkSNARKs -* [`ark-relations`](relations): Provides generic traits for NP relations used in programming zkSNARKs, such as R1CS +* [`ark-relations`](relations): Provides generic traits for NP relations used in programming zkSNARKs, such as R1CS and Arithmetic Circuits ## Overview diff --git a/circom/cube.circom b/circom/cube.circom new file mode 100644 index 000000000..73bbd5882 --- /dev/null +++ b/circom/cube.circom @@ -0,0 +1,11 @@ +template Cube() { + + signal input x; + signal x2; + + x2 <== x * x; + + x * x2 === 27; +} + +component main = Cube(); diff --git a/circom/cube.r1cs b/circom/cube.r1cs new file mode 100644 index 000000000..cf1bc4335 Binary files /dev/null and b/circom/cube.r1cs differ diff --git a/circom/cube.wasm b/circom/cube.wasm new file mode 100644 index 000000000..e35b15f18 Binary files /dev/null and b/circom/cube.wasm differ diff --git a/circom/multiplication.circom b/circom/multiplication.circom new file mode 100644 index 000000000..676c6ecfd --- /dev/null +++ b/circom/multiplication.circom @@ -0,0 +1,9 @@ +template Multiplication() { + signal input s1; + signal input s2; + signal output y; + + y <== s1 * s2; +} + +component main = Multiplication(); diff --git a/circom/multiplication.r1cs b/circom/multiplication.r1cs new file mode 100644 index 000000000..e61b9b091 Binary files /dev/null and b/circom/multiplication.r1cs differ diff --git a/circom/multiplication.wasm b/circom/multiplication.wasm new file mode 100644 index 000000000..248fdfb46 Binary files /dev/null and b/circom/multiplication.wasm differ diff --git a/relations/Cargo.toml b/relations/Cargo.toml index 7af412eeb..0d19e61c4 100644 --- a/relations/Cargo.toml +++ b/relations/Cargo.toml @@ -17,10 +17,16 @@ ark-ff.workspace = true ark-std.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, optional = true } +itertools = { version = "0.13.0", default-features = false } [dev-dependencies] ark-test-curves = { workspace = true, features = [ "bls12_381_scalar_field" ] } +tokio = { version = "1", features = [ "full" ] } +ark-ec = { version = "0.5.0", default-features = false } +ark-circom = { git = "https://github.com/arkworks-rs/circom-compat", branch = "release-0.5" } +ark-bn254 = { version = "0.5.0", default-features = false, features = [ "curve" ] } +ark-bls12-377 = { version = "0.5.0", default-features = false, features = [ "curve" ] } [features] -default = [] -std = [ "ark-std/std", "ark-ff/std", "tracing-subscriber", "tracing/std" ] +default = [ "itertools/use_alloc" ] +std = [ "ark-std/std", "ark-ff/std", "tracing-subscriber", "tracing/std", "ark-ec/std", "itertools/use_alloc" ] diff --git a/relations/README.md b/relations/README.md new file mode 100644 index 000000000..c097d9722 --- /dev/null +++ b/relations/README.md @@ -0,0 +1,61 @@ +## Arithmetic circuits + +Arithmetic circuits are a way to represent polynomial expressions as a sequence of addition and multiplication operations, which we conceptualise as gates. These operations take the form of nodes in a directed acyclic graph, where the edges represent the flow of data between the nodes. In this implementation, gates have a fan-in of two (i.e. two inputs) and arbitrary fan-out (i.e. their output can be the output of arbitrarily many other gates). We represent arithmetic circuits as an array of `Constants`, `Variables`, `Add` gates and `Mul` gates. Gates have a left and right input, which are indices for the respective nodes in the array. In order to directly construct an arithmetic circuit, the user must instantiate an empty `ArithmeticCircuit` struct and mutate it via its public methods (e.g. `new_variable`, `add`, `mul`), each of which returns the index of the new node in the array. For example, an arithmetic circuit expressing the computation `2 * x + y` can be constructed as follows: + +```rust + let mut circuit = ArithmeticCircuit::new(); + let two = circuit.constant(F::from(2)); + let x = circuit.new_variable(); + let y = circuit.new_variable(); + let two_x = circuit.mul(two, x); + let result = circuit.add(two_x, y); +``` + +Variables can also be given labels for easier value assignment and tracking: + +```rust + let x = circuit.new_variable_with_label("x"); + let y = circuit.new_variable_with_label("y"); +``` + +We note that there is only one `Constant` node for each field element `v` appearing in the computation: subsequent calls to `circuit.constant(v)` will point to the same node and therefore can be transparently made without incuring unnecessary spatial costs. + + +## Arithmetic expressions + +We also provide tooling to generate arithmetic circuits from user-friendly `Expression`s. The latter allow the programmer to write mathematical formulas in a more human-readable way (e.g., `a + b * c`) and can subsequently be converted into an `ArithmeticCircuit`. For comparison, an arithmetic circuit for the same computation `2 * x + y` can be constructed as follows: + +```rust + let x = Expression::variable("x"); + let y = Expression::variable("y"); + let result = 2 * x + y; + let circuit = result.to_arithmetic_circuit(); +``` + +In the case of expressions, variable labels are indispensable anchors to each individual variable after the expression is compiled into a circuit. + +Due to Rust's borrow-checker, an expression needs to be cloned if it is used more than once in the same line. For instance, the following +```rust + let expression = x * x + y; +``` +will not compile, the correct syntax being: +```rust + let expression = x.clone() * x + y; +``` +We note that cloning expressions is very cheap, since they are implemented using the `Rc` struct. This and other pain points of expression sysntax may be ironed out in the future. + +## R1CS to arithmetic circuits + +Our implementation also includes the method `from_constraint_system`, which allows the user to convert an Arkworks `ConstraintSystem` (i.e., an R1CS) into an `ArithmeticCircuit`. The method takes as input a `ConstraintSystem` struct, which contains the R1CS matrices `A`, `B`, and `C`. A `ConstraintSystem` can be obtained from the circom generated `.r1cs` and `.wasm` files,via the `read_constraint_system` method. + +## Generating R1CS files +In order to generate an `.r1cs` file from a `.circom` one (with name, say, `NAME`), use +``` + circom NAME.circom --r1cs +``` + +In order to generate a `.wasm` file from a `.circom` one, use +``` + circom NAME.circom --wasm +``` +and take the `.wasm` file from within the newly created folder. \ No newline at end of file diff --git a/relations/src/arithmetic_circuit/mod.rs b/relations/src/arithmetic_circuit/mod.rs new file mode 100644 index 000000000..06cb8574b --- /dev/null +++ b/relations/src/arithmetic_circuit/mod.rs @@ -0,0 +1,680 @@ +//! This module contains the core functionality for arithmetic circuits. +use crate::r1cs::{ConstraintMatrices, ConstraintSystem}; +use ark_ff::{BigInteger, BigInteger256, PrimeField}; +use ark_std::{ + assert, + clone::Clone, + cmp::PartialEq, + collections::BTreeMap, + convert::From, + fmt::{Display, Formatter, Result}, + format, + iter::{IntoIterator, Iterator}, + option::{ + Option, + Option::{None, Some}, + }, + panic, + prelude::rust_2021::{derive, Debug}, + string::{String, ToString}, + unreachable, + vec::Vec, +}; + +#[cfg(test)] +mod tests; + +/// Represents a node in an arithmetic circuit. A node is either a variable, a +/// constant, or a gate. +#[derive(Debug, Clone, PartialEq)] +pub enum Node { + /// Variable set individually for each execution + // Since no two variables have the same label, no memory cost is incurred + // due to owning the string as opposed to a &'a str + Variable(String), + /// Constant across all executions + Constant(F), + /// Addition gate with indices of its left and right input within a larger + /// circuit + Add(usize, usize), + /// Multiplication gate with indices of its left and right input within a + /// larger circuit + Mul(usize, usize), +} + +#[derive(Debug, Clone, PartialEq)] + +/// Represents an arithmetic circuit over a field F. An arithmetic circuit is a +/// directed acyclic graph where nodes are either variables, constants, or +/// gates for addition and multiplication. +pub struct ArithmeticCircuit { + /// List of nodes of the circuit + pub nodes: Vec>, + /// Hash map of constants defined in the circuit in order to avoid + /// duplication + pub constants: BTreeMap, + /// Map from variable labels to node indices + pub variables: BTreeMap, + /// Big-endian bit decomposition of F::MODULUS - 1, without initial zeros + pub(crate) unit_group_bits: Option>, +} + +impl ArithmeticCircuit { + /// Creates a new, empty arithmetic circuit. + pub fn new() -> Self { + Self { + nodes: Vec::new(), + constants: BTreeMap::new(), + variables: BTreeMap::new(), + unit_group_bits: Option::None, + } + } + + /// Returns the number of nodes in the circuit. + pub fn num_nodes(&self) -> usize { + self.nodes.len() + } + + /// Returns the number of constants in the circuit. + pub fn num_constants(&self) -> usize { + self.constants.len() + } + + /// Returns the number of variables in the circuit. + pub fn num_variables(&self) -> usize { + self.variables.len() + } + + /// Returns the index of the last node in the circuit. + pub fn last(&self) -> usize { + self.nodes.len() - 1 + } + + /// Returns the number of addition and multiplication gates in the circuit. + pub fn num_gates(&self) -> usize { + self.nodes + .iter() + .filter(|node| match node { + Node::Add(..) | Node::Mul(..) => true, + _ => false, + }) + .count() + } + + /// Returns existing constant with value `value` if it exists, or creates a + /// new one and returns its index if it does not + pub fn constant(&mut self, value: F) -> usize { + if let Some(index) = self.constants.get(&value) { + *index + } else { + let index = self.push_node(Node::Constant(value)); + self.constants.insert(value, index); + index + } + } + + /// Creates a variable with the given label + /// + /// # Panics + /// Panics if the circuit already contains a variable with name `var_N` + // Receiving &str to ease caller syntax + pub fn new_variable_with_label(&mut self, label: &str) -> usize { + let index = self.push_node(Node::Variable(label.to_string())); + + if self.variables.insert(label.to_string(), index).is_some() { + panic!("Variable label already in use: {label}"); + } + + index + } + + /// Creates a variable with the label `var_N`, where `N` is the number of + /// variables in the circuit + /// + /// # Panics + /// Panics if the circuit already contains a variable with name `var_N` + pub fn new_variable(&mut self) -> usize { + self.new_variable_with_label(&format!("var_{}", self.num_variables())) + } + + /// Creates `num` new variables + pub fn new_variables(&mut self, num: usize) -> Vec { + (0..num).map(|_| self.new_variable()).collect() + } + + /// Returns the index of the variable with label `label` + pub fn get_variable(&self, label: &str) -> usize { + *self.variables.get(label).expect("Variable not in circuit") + } + + /// Adds the two nodes, checking that they are in the circuit + pub fn add(&mut self, left: usize, right: usize) -> usize { + let length = self.nodes.len(); + assert!(left < length, "Left operand to Add not in circuit:"); + assert!(right < length, "Right operand to Add not in circuit:"); + + self.push_node(Node::Add(left, right)) + } + + /// Multiplies the two nodes without checking that they are in the circuit + pub fn mul_unchecked(&mut self, left: usize, right: usize) -> usize { + self.push_node(Node::Mul(left, right)) + } + + /// Multiplies the two nodes, checking that they are in the circuit + pub fn mul(&mut self, left: usize, right: usize) -> usize { + let length = self.nodes.len(); + assert!(left < length, "Left operand to Mul not in circuit:"); + assert!(right < length, "Right operand to Mul not in circuit:"); + + self.push_node(Node::Mul(left, right)) + } + + /// Adds all nodes in the given iterator + pub fn add_nodes(&mut self, indices: impl IntoIterator) -> usize { + indices + .into_iter() + .reduce(|acc, index| self.add(acc, index)) + .unwrap() + } + + /// Multiplies all nodes in the given list + pub fn mul_nodes(&mut self, indices: impl IntoIterator) -> usize { + indices + .into_iter() + .reduce(|acc, index| self.mul(acc, index)) + .unwrap() + } + + /// Computes node^exponent, where exponent is a BigUint + pub fn pow_bigint(&mut self, node: usize, exponent: BigInteger256) -> usize { + assert!( + node < self.num_nodes(), + "Base node ({node}) not in the circuit (which contains {} nodes)", + self.num_nodes() + ); + + let binary_decomposition = exponent + .to_bits_be() + .into_iter() + .map(|b| b) + .skip_while(|b| !b) + .collect::>(); + + self.pow_binary(node, &binary_decomposition) + } + + /// Computes node^exponent, where exponent is a usize + pub fn pow(&mut self, node: usize, exponent: usize) -> usize { + self.pow_bigint(node, BigInteger256::from(exponent as u64)) + } + + // Standard square-and-multiply. The first bit is always one, so we can + // skip it and initialise the accumulator to node instead of 1 + fn pow_binary(&mut self, node: usize, binary_decomposition: &Vec) -> usize { + let mut current = node; + + for bit in binary_decomposition.iter().skip(1) { + current = self.mul_unchecked(current, current); + + if *bit { + current = self.mul_unchecked(current, node); + } + } + + current + } + + /// Computes the node x^(F::MODULUS - 1), which is 0 if x = 0 and 1 + /// otherwise + pub fn indicator(&mut self, node: usize) -> usize { + let unit_group_bits = self + .unit_group_bits + .get_or_insert_with(|| { + let mod_minus_one: F::BigInt = (-F::ONE).into(); + mod_minus_one + .to_bits_be() + .into_iter() + .skip_while(|b| !b) + .collect() + }) + .clone(); + + self.pow_binary(node, &unit_group_bits) + } + + /// Computes the negation of the given node + pub fn minus(&mut self, node: usize) -> usize { + let minus_one = self.constant(-F::ONE); + self.mul(minus_one, node) + } + + /// Computes the scalar product of two vectors of nodes. Does NOT perform + /// optimisations by, for instance, skipping multiplication of the form 1 * + /// x or 0 * x, or omitting addition of zero terms. + pub fn scalar_product( + &mut self, + left: impl IntoIterator, + right: impl IntoIterator, + ) -> usize { + let products = left + .into_iter() + .zip(right) + .map(|(l, r)| self.mul_unchecked(l, r)) + .collect::>(); + self.add_nodes(products) + } + + fn push_node(&mut self, node: Node) -> usize { + self.nodes.push(node); + self.nodes.len() - 1 + } + + // Auxiliary recursive function which evaluation_trace wraps around + fn inner_evaluate(&self, node_index: usize, node_assignments: &mut Vec>) { + if node_assignments[node_index].is_some() { + return; + } + + let node = &self.nodes[node_index]; + + match node { + Node::Variable(_) => panic!("Uninitialised variable"), + Node::Constant(_) => panic!("Uninitialised constant"), + Node::Add(left, right) | Node::Mul(left, right) => { + self.inner_evaluate(*left, node_assignments); + self.inner_evaluate(*right, node_assignments); + + let left_value = node_assignments[*left].unwrap(); + let right_value = node_assignments[*right].unwrap(); + + node_assignments[node_index] = Some(match node { + Node::Add(..) => left_value + right_value, + Node::Mul(..) => left_value * right_value, + _ => unreachable!(), + }); + }, + } + } + + // ************************ Evaluation functions *************************** + + /// Evaluate all nodes required to compute the output node, returning the + /// full vector of intermediate node values. Nodes not involved in the + /// computation (and not passed as part of the variable assignment) are left + /// as None + /// + /// # Panics + /// Panics if a variable index is not found in the circuit. + pub fn evaluation_trace(&self, vars: Vec<(usize, F)>, node: usize) -> Vec> { + let mut node_assignments = self + .nodes + .iter() + .map(|node| { + if let Node::Constant(c) = node { + Some(*c) + } else { + None + } + }) + .collect::>>(); + + // This does not check (for efficiency reasons) that each variable was + // supplied with only one value: in the case of duplicates, the latest + // one in the list is used + for (index, value) in vars { + if let Node::Variable(_) = self.nodes[index] { + node_assignments[index] = Some(value); + } else { + panic!("Value supplied for non-variable node"); + } + } + + self.inner_evaluate(node, &mut node_assignments); + + node_assignments + } + + /// Similar to `evaluation_trace`, but using variable labels instead of + /// node indices. + /// + /// # Panics + /// Panics if a variable label is not found in the circuit. + pub fn evaluation_trace_with_labels( + &self, + vars: Vec<(&str, F)>, + node: usize, + ) -> Vec> { + let vars = vars + .into_iter() + .map(|(label, value)| (self.get_variable(label), value)) + .collect::>(); + + self.evaluation_trace(vars, node) + } + + /// Similar to `evaluation_trace`, but evaluating multiple output nodes. + /// This function is useful for evaluating constraints, in the case of + /// rank-1 constraints, this corresponds to evaluating the constraint for + /// multiple different assignments to the instance and witness variables + /// simultaneously. Returns a vector of `Option` values which are `Some` + /// if the node is set, and `None` otherwise. + /// + /// # Panics + /// Panics if a variable label is not found in the circuit. + pub fn evaluation_trace_multioutput( + &self, + vars: Vec<(usize, F)>, + outputs: &Vec, + ) -> Vec> { + let mut node_assignments = self + .nodes + .iter() + .map(|node| { + if let Node::Constant(c) = node { + Some(*c) + } else { + None + } + }) + .collect::>>(); + + // This does not check (for efficiency reasons) that each variable was + // supplied with only one value: in the case of duplicates, the latest + // one in the list is used + for (index, value) in vars { + if let Node::Variable(_) = self.nodes[index] { + node_assignments[index] = Some(value); + } else { + panic!("Value supplied for non-variable node"); + } + } + + outputs + .iter() + .for_each(|node| self.inner_evaluate(*node, &mut node_assignments)); + + node_assignments + } + + /// Similar to `evaluation_trace_multioutput`, but using variable labels + /// instead of node indices. + /// + /// # Panics + /// Panics if a variable label is not found in the circuit. + pub fn evaluation_trace_multioutput_with_labels( + &self, + vars: Vec<(&str, F)>, + outputs: &Vec, + ) -> Vec> { + let vars = vars + .into_iter() + .map(|(label, value)| (self.get_variable(label), value)) + .collect::>(); + + self.evaluation_trace_multioutput(vars, outputs) + } + + /// Evaluates a single node, returning the value of the node. + /// + /// # Panics + /// Panics if the node is not assigned a value. + pub fn evaluate_node(&self, vars: Vec<(usize, F)>, node: usize) -> F { + self.evaluation_trace(vars, node)[node].unwrap() + } + + /// Similar to `evaluate_node`, but using variable labels instead of node + /// indices. + /// + /// # Panics + /// Panics if the node is not assigned a value. + pub fn evaluate_node_with_labels(&self, vars: Vec<(&str, F)>, node: usize) -> F { + self.evaluation_trace_with_labels(vars, node)[node].unwrap() + } + + /// Similar to `evaluation_trace_multioutput`, but returning the values of + /// only the output nodes. + pub fn evaluate_multioutput(&self, vars: Vec<(usize, F)>, outputs: &Vec) -> Vec { + self.evaluation_trace_multioutput(vars, outputs) + .into_iter() + .enumerate() + .filter_map(|(i, v)| if outputs.contains(&i) { v } else { None }) + .collect() + } + + /// Similar to `evaluate_multioutput`, but using variable labels instead of + /// node indices. + pub fn evaluate_multioutput_with_labels( + &self, + vars: Vec<(&str, F)>, + outputs: &Vec, + ) -> Vec { + self.evaluation_trace_multioutput_with_labels(vars, outputs) + .into_iter() + .enumerate() + .filter_map(|(i, v)| if outputs.contains(&i) { v } else { None }) + .collect() + } + + /// Evaluates the circuit at the last node, returning the value of the last + /// node. + pub fn evaluate(&self, vars: Vec<(usize, F)>) -> F { + self.evaluate_node(vars, self.last()) + } + + /// Similar to `evaluate`, but using variable labels instead of node + /// indices. + pub fn evaluate_with_labels(&self, vars: Vec<(&str, F)>) -> F { + self.evaluate_node_with_labels(vars, self.last()) + } + + /// Prints the evaluation trace for a given variable assignment and node + #[cfg(feature = "std")] + pub fn print_evaluation_trace(&self, var_assignment: Vec<(usize, F)>, node: usize) { + println!("Arithmetic circuit with {} nodes:", self.num_nodes()); + + let evaluations = self.evaluation_trace(var_assignment, node); + + for (index, (node, value)) in self.nodes.iter().zip(evaluations.iter()).enumerate() { + if let Node::Constant(c) = node { + println!("\t{index}: Constant = {c:?}"); + } else { + let value = if let Some(v) = value { + format!("{v:?}") + } else { + "not set".to_string() + }; + + println!("\t{index}: {node} = {value}"); + } + } + } + + /// Similar to `print_evaluation_trace`, but allowing for multiple outputs + #[cfg(feature = "std")] + pub fn print_evaluation_trace_multioutput( + &self, + var_assignment: Vec<(usize, F)>, + outputs: &Vec, + ) { + println!("Arithmetic circuit with {} nodes:", self.num_nodes()); + + let evaluations = self.evaluation_trace_multioutput(var_assignment, outputs); + + for (index, (node, value)) in self.nodes.iter().zip(evaluations.iter()).enumerate() { + if let Node::Constant(c) = node { + println!("\t{index}: Constant = {c:?}"); + } else { + let value = if let Some(v) = value { + format!("{v:?}") + } else { + "not set".to_string() + }; + + println!("\t{index}: {node} = {value}"); + } + } + } + + // ************************ Compilation functions ************************** + + /// Compiles an R1CS constraint system into an arithmetic circuit. The + /// R1CS constraint system is parsed row by row, and each row is compiled + /// into a sequence of addition and multiplication gates. The output of the + /// function is a tuple containing the compiled circuit and a vector of + /// indices corresponding to the output nodes. Given a valid assignment to + /// the variables, the values of the output nodes will be one iff the + /// constraint is satisfied. + pub fn from_constraint_system(cs: &ConstraintSystem) -> (Self, Vec) { + let ConstraintMatrices { a, b, c, .. } = cs.to_matrices().unwrap(); + + let mut circuit = ArithmeticCircuit::new(); + let one = circuit.constant(F::ONE); + circuit.new_variables(cs.num_instance_variables + cs.num_witness_variables - 1); + + let mut row_expressions = |matrix: Vec>| { + matrix + .into_iter() + .map(|row| circuit.compile_sparse_scalar_product(row)) + .collect::>() + }; + + // Az, Bz, Cz + let a = row_expressions(a); + let b = row_expressions(b); + let c = row_expressions(c); + + // Az (hadamard) Bz + let pairwise_mul_a_b = a + .into_iter() + .zip(b) + .map(|(a, b)| circuit.mul(a, b)) + .collect::>(); + + let minus_one = circuit.constant(-F::ONE); + let minus_c = c + .into_iter() + .map(|c| circuit.mul(c, minus_one)) + .collect::>(); + + // Az * Bz - Cz + 1 + let outputs = pairwise_mul_a_b + .into_iter() + .zip(minus_c) + .map(|(ab, m_c)| circuit.add_nodes([ab, m_c, one])) + .collect::>(); + + (circuit, outputs) + } + + // Compile a sparse scalar product into nodes. Relies on some assumptions + // guaranteed by `from_constraint_systems`, which should be the only caller. + // Performs certain optimisations, most notably: terms of the form C * 1 and + // 1 * V are simplified to C and V respectively. + fn compile_sparse_scalar_product(&mut self, sparse_row: Vec<(F, usize)>) -> usize { + let constants = sparse_row + .into_iter() + .map(|(c, var_index)| (self.constant(c), var_index)) + .collect::>(); + + let products = constants + .into_iter() + .map(|(c_index, var_index)| { + // If either the constant or the variable is ONE, we can just return the other + if c_index == 0 || var_index == 0 { + c_index + var_index + } else { + self.mul(c_index, var_index) + } + }) + .collect::>(); + + self.add_nodes(products) + } +} + +impl Display for Node { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match self { + Node::Variable(label) => write!(f, "{}", label), + Node::Constant(c) => write!(f, "Constant({})", c), + Node::Add(left, right) => write!(f, "node({}) + node({})", left, right), + Node::Mul(left, right) => write!(f, "node({}) * node({})", left, right), + } + } +} + +impl Display for ArithmeticCircuit { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + writeln!(f, "Arithmetic circuit with {} nodes:", self.num_nodes())?; + + for (index, node) in self.nodes.iter().enumerate() { + writeln!(f, "\t{}: {}", index, node)?; + } + Ok(()) + } +} + +// Discards duplicated constants and updates all gate relations accordingly +pub(crate) fn filter_constants( + nodes: &Vec>, +) -> (Vec>, BTreeMap) { + // Map of unique constants mapping sending value to final position + let mut constants = BTreeMap::new(); + + // Mapping from original indices to post-constant-removal indices + let mut filtered_indices = BTreeMap::new(); + + let mut removed_constants = 0; + + nodes.iter().enumerate().for_each(|(i, node)| match node { + Node::Constant(v) => { + if constants.contains_key(v) { + removed_constants += 1; + } else { + constants.insert(*v, i - removed_constants); + filtered_indices.insert(i, i - removed_constants); + } + }, + _ => { + filtered_indices.insert(i, i - removed_constants); + }, + }); + + // TODO possibly change to into_iter and avoid node cloning if the + // borrow checker can find it in its heart to accept that + let new_nodes = nodes + .iter() + .enumerate() + .filter_map(|(i, node)| { + match node { + Node::Constant(_) => { + // Checking if this is the first appearance of the constant + if filtered_indices.contains_key(&i) { + Some(node.clone()) + } else { + None + } + }, + Node::Variable(_) => Some(node.clone()), + Node::Add(left, right) | Node::Mul(left, right) => { + let updated_left = match nodes[*left] { + Node::Constant(c) => *constants.get(&c).unwrap(), + _ => *filtered_indices.get(left).unwrap(), + }; + let updated_right = match nodes[*right] { + Node::Constant(c) => *constants.get(&c).unwrap(), + _ => *filtered_indices.get(right).unwrap(), + }; + match node { + Node::Add(..) => Some(Node::Add(updated_left, updated_right)), + Node::Mul(..) => Some(Node::Mul(updated_left, updated_right)), + _ => unreachable!(), + } + }, + } + }) + .collect(); + + (new_nodes, constants) +} diff --git a/relations/src/arithmetic_circuit/tests.rs b/relations/src/arithmetic_circuit/tests.rs new file mode 100644 index 000000000..1ab2b78a4 --- /dev/null +++ b/relations/src/arithmetic_circuit/tests.rs @@ -0,0 +1,376 @@ +use ark_ec::short_weierstrass::Affine; +use ark_ff::{Field, UniformRand}; +use ark_std::{string::ToString, test_rng, vec::Vec}; + +use crate::{ + arithmetic_circuit::{filter_constants, ArithmeticCircuit, Node}, + reader::read_constraint_system, + TEST_DATA_PATH, +}; + +use ark_bls12_377::{Fq as FqBLS, G1Affine}; +use ark_bn254::Fr as FrBN; + +/// Generates the arithmetic circuit for the BLS12-377 elliptic curve. +/// +/// The curve is defined by the equation: +/// 1 + (1 + x^3 - y^2) = 1 +pub fn generate_bls12_377_circuit() -> ArithmeticCircuit { + let mut circuit = ArithmeticCircuit::new(); + + let one = circuit.constant(FqBLS::ONE); + + let x = circuit.new_variable_with_label("x"); + let y = circuit.new_variable_with_label("y"); + + let y_squared = circuit.pow(y, 2); + let minus_y_squared = circuit.minus(y_squared); + let x_cubed = circuit.pow(x, 3); + + circuit.add_nodes([x_cubed, one, minus_y_squared, one]); + circuit +} + +/// Generates the arithmetic circuit for the lemniscate curve defined by: +/// (x^2 + y^2)^2 - 120x^2 + 80y^2 + 1 = 1 +pub fn generate_lemniscate_circuit() -> ArithmeticCircuit { + let mut circuit = ArithmeticCircuit::new(); + + let one = circuit.constant(FrBN::ONE); + + let x = circuit.new_variable(); + let y = circuit.new_variable(); + + let a = circuit.constant(FrBN::from(120)); + let b = circuit.constant(FrBN::from(80)); + + let x_2 = circuit.mul(x, x); + let y_2 = circuit.mul(y, y); + + let a_x_2 = circuit.mul(a, x_2); + let b_y_2 = circuit.mul(b, y_2); + let minus_a_x_2 = circuit.minus(a_x_2); + + let x_2_plus_y_2 = circuit.add(x_2, y_2); + let b_y_2_minus_a_x_2 = circuit.add(b_y_2, minus_a_x_2); + + let x_2_plus_y_2_2 = circuit.mul(x_2_plus_y_2, x_2_plus_y_2); + + circuit.add_nodes([x_2_plus_y_2_2, b_y_2_minus_a_x_2, one]); + circuit +} + +/// Generates the arithmetic circuit for the determinant of a 3x3 matrix. +pub fn generate_3_by_3_determinant_circuit() -> ArithmeticCircuit { + let mut circuit = ArithmeticCircuit::new(); + + let one = circuit.constant(FrBN::ONE); + + let vars = circuit.new_variables(9); + let det = circuit.new_variable(); + + let aei = circuit.mul_nodes([vars[0], vars[4], vars[8]]); + let bfg = circuit.mul_nodes([vars[1], vars[5], vars[6]]); + let cdh = circuit.mul_nodes([vars[2], vars[3], vars[7]]); + + let ceg = circuit.mul_nodes([vars[2], vars[4], vars[6]]); + let bdi = circuit.mul_nodes([vars[1], vars[3], vars[8]]); + let afh = circuit.mul_nodes([vars[0], vars[5], vars[7]]); + + let sum1 = circuit.add_nodes([aei, bfg, cdh]); + let sum2 = circuit.add_nodes([ceg, bdi, afh]); + + let minus_sum2 = circuit.minus(sum2); + let minus_det = circuit.minus(det); + + circuit.add_nodes([sum1, minus_sum2, minus_det, one]); + circuit +} + +#[test] +fn test_add_constants() { + let mut circuit = ArithmeticCircuit::new(); + let one = circuit.constant(FrBN::ONE); + let two = circuit.constant(FrBN::from(2)); + circuit.add(one, two); + assert_eq!(circuit.evaluate(vec![]), FrBN::from(3)); +} + +#[test] +fn test_mul_constants() { + let mut circuit = ArithmeticCircuit::new(); + let a = circuit.constant(FrBN::from(6)); + let b = circuit.constant(FrBN::from(2)); + circuit.mul(a, b); + assert_eq!(circuit.evaluate(vec![]), FrBN::from(12)); +} + +#[test] +fn test_pow_constants() { + let mut circuit = ArithmeticCircuit::new(); + let two = circuit.constant(FrBN::from(2)); + circuit.pow(two, 5); + assert_eq!(circuit.evaluate(vec![]), FrBN::from(32)); +} + +#[test] +fn test_add_variables() { + let mut circuit = ArithmeticCircuit::new(); + let input = circuit.new_variables(2); + circuit.add(input[0], input[1]); + assert_eq!( + circuit.evaluate(vec![(input[0], FrBN::from(2)), (input[1], FrBN::from(3))]), + FrBN::from(5) + ); +} + +#[test] +fn test_mul_variables() { + let mut circuit = ArithmeticCircuit::new(); + let input = circuit.new_variables(2); + circuit.mul(input[0], input[1]); + assert_eq!( + circuit.evaluate(vec![(input[0], FrBN::from(2)), (input[1], FrBN::from(3))]), + FrBN::from(6) + ); +} + +#[test] +fn test_pow_variable() { + let mut circuit = ArithmeticCircuit::new(); + let a = circuit.new_variable(); + circuit.pow(a, 4); + assert_eq!(circuit.evaluate(vec![(a, FrBN::from(2))]), FrBN::from(16)); +} + +#[test] +fn test_indicator() { + let mut circuit = ArithmeticCircuit::new(); + let a = circuit.new_variable(); + circuit.indicator(a); + assert_eq!( + circuit.evaluate(vec![(a, FrBN::rand(&mut test_rng()))]), + FrBN::from(1) + ); +} + +#[tokio::test] +async fn test_multiplication() { + let cs = read_constraint_system::( + &format!(TEST_DATA_PATH!(), "multiplication.r1cs"), + &format!(TEST_DATA_PATH!(), "multiplication.wasm"), + ); + + let (circuit, _) = ArithmeticCircuit::::from_constraint_system(&cs); + + let (a, b, c) = (FrBN::from(6), FrBN::from(3), FrBN::from(2)); + let valid_assignment = vec![(1, a), (2, b), (3, c)]; + + assert_eq!(circuit.evaluate(valid_assignment), FrBN::ONE); +} + +#[tokio::test] +async fn test_cube_multioutput() { + let r1cs = read_constraint_system::( + &format!(TEST_DATA_PATH!(), "cube.r1cs"), + &format!(TEST_DATA_PATH!(), "cube.wasm"), + ); + + let (circuit, outputs) = ArithmeticCircuit::from_constraint_system(&r1cs); + + let mut clever_circuit = ArithmeticCircuit::new(); + let x = clever_circuit.new_variable(); + let x_cubed = clever_circuit.pow(x, 3); + let c = clever_circuit.constant(-FrBN::from(26)); + clever_circuit.add(x_cubed, c); + + let mut another_clever_circuit = ArithmeticCircuit::new(); + let a_x = another_clever_circuit.new_variable(); + let a_x_2 = another_clever_circuit.mul(a_x, a_x); + let a_x_cubed = another_clever_circuit.mul(a_x_2, a_x); + let a_c = another_clever_circuit.constant(-FrBN::from(26)); + another_clever_circuit.add(a_x_cubed, a_c); + + let mut yet_another_clever_circuit = ArithmeticCircuit::new(); + let y_a_x = yet_another_clever_circuit.new_variable(); + let y_a_x_cubed = yet_another_clever_circuit.mul_nodes([y_a_x, y_a_x, y_a_x]); + let y_a_c = yet_another_clever_circuit.constant(-FrBN::from(26)); + yet_another_clever_circuit.add(y_a_x_cubed, y_a_c); + + let evaluation_trace = circuit + .evaluation_trace_multioutput(vec![(1, FrBN::from(3)), (2, FrBN::from(9))], &outputs); + assert_eq!( + outputs + .into_iter() + .map(|output| evaluation_trace[output].unwrap()) + .collect::>(), + vec![FrBN::ONE, FrBN::ONE], + ); + + [ + &clever_circuit, + &another_clever_circuit, + &yet_another_clever_circuit, + ] + .iter() + .for_each(|circuit| assert_eq!(circuit.evaluate(vec![(0, FrBN::from(3))]), FrBN::ONE)); + + assert_eq!(clever_circuit, another_clever_circuit); + assert_eq!(clever_circuit, yet_another_clever_circuit); + + // With the indicator-based compiler, this would result in 719 gates + assert_eq!(circuit.num_nodes(), 15); + assert_eq!(clever_circuit.num_gates(), 3); +} + +#[test] +fn test_fibonacci() { + let mut circ = ArithmeticCircuit::::new(); + + let f_0 = circ.new_variable(); + let f_1 = circ.new_variable(); + + let mut first_operand = f_0; + let mut second_operand = f_1; + + for _ in 3..50 { + let next = circ.add(first_operand, second_operand); + first_operand = second_operand; + second_operand = next; + } + + let f_42 = FrBN::from(267914296); + + // Checking F_42 + assert_eq!( + circ.evaluate_node(vec![(f_0, FrBN::ONE), (f_1, FrBN::ONE)], 42 - 1), + f_42, + ); + + // Checking F_42 after shifting the entire sequence by 4 positions + assert_eq!( + circ.evaluate_node(vec![(f_0, FrBN::from(5)), (f_1, FrBN::from(8))], 42 - 5), + f_42, + ); +} + +#[test] +fn test_fibonacci_with_const() { + let mut circ = ArithmeticCircuit::::new(); + + let f_0 = circ.constant(FrBN::ONE); + let f_1 = circ.new_variable(); + + let mut first_operand = f_0; + let mut second_operand = f_1; + + for _ in 3..50 { + let next = circ.add(first_operand, second_operand); + first_operand = second_operand; + second_operand = next; + } + + let f_42 = FrBN::from(267914296); + + // Checking F_42 + assert_eq!(circ.evaluate_node(vec![(f_1, FrBN::ONE)], 42 - 1), f_42); +} + +#[test] +fn test_bls12_377_circuit() { + let circuit = generate_bls12_377_circuit(); + + let Affine { x, y, .. } = G1Affine::rand(&mut test_rng()); + + assert_eq!(y.pow([2]), x.pow([3]) + FqBLS::ONE); + + let valid_assignment = vec![(1, x), (2, y)]; + assert_eq!(circuit.evaluate(valid_assignment.clone()), FqBLS::ONE); +} + +#[test] +fn test_lemniscate_circuit() { + let circuit = generate_lemniscate_circuit(); + + let x = FrBN::from(8); + let y = FrBN::from(4); + + let valid_assignment = vec![(1, x), (2, y)]; + assert_eq!(circuit.evaluate(valid_assignment), FrBN::ONE); +} + +#[test] +fn test_generate_3_by_3_determinant_circuit() { + let circuit = generate_3_by_3_determinant_circuit(); + + let vars = (1..=9) + .map(|i| (i, FrBN::from(i as u64))) + .collect::>(); + let det = FrBN::from(0); + let valid_assignment = [vars, vec![(10, det)]].concat(); + + assert_eq!(circuit.evaluate(valid_assignment), FrBN::ONE); + + let circuit = generate_3_by_3_determinant_circuit(); + + let vars = vec![ + (1, FrBN::from(2)), + (2, FrBN::from(0)), + (3, FrBN::from(-1)), + (4, FrBN::from(3)), + (5, FrBN::from(5)), + (6, FrBN::from(2)), + (7, FrBN::from(-4)), + (8, FrBN::from(1)), + (9, FrBN::from(4)), + ]; + let det = FrBN::from(13); + let valid_assignment = [vars, vec![(10, det)]].concat(); + + assert_eq!(circuit.evaluate(valid_assignment), FrBN::ONE); +} + +#[test] +pub fn test_constant_filtering() { + let nodes: Vec> = vec![ + Node::Variable("x".to_string()), // 0 -> 0 + Node::Constant(FqBLS::from(3)), // 1 -> 1 + Node::Constant(FqBLS::from(3)), // 2 ---- + Node::Variable("y".to_string()), // 3 -> 2 + Node::Mul(18, 2), // 4 -> 3 + Node::Constant(-FqBLS::from(1)), // 5 -> 4 + Node::Mul(4, 1), // 6 -> 5 + Node::Mul(2, 2), // 7 -> 6 + Node::Constant(FqBLS::from(4)), // 8 -> 7 + Node::Mul(7, 7), // 9 -> 8 + Node::Constant(-FqBLS::from(1)), // 10 ---- + Node::Add(8, 5), // 11 -> 9 + Node::Add(8, 14), // 12 -> 10 + Node::Mul(17, 10), // 13 -> 11 + Node::Constant(FqBLS::from(3)), // 14 ----- + Node::Constant(-FqBLS::from(2)), // 15 -> 12 + Node::Variable("z".to_string()), // 16 -> 13 + Node::Constant(-FqBLS::from(1)), // 17 ----- + Node::Add(12, 5), // 18 -> 14 + ]; + + let filtered_nodes: Vec> = vec![ + Node::Variable("x".to_string()), // 0 -> 0 + Node::Constant(FqBLS::from(3)), // 1 -> 1 + Node::Variable("y".to_string()), // 3 -> 2 + Node::Mul(14, 1), // 4 -> 3 + Node::Constant(-FqBLS::from(1)), // 5 -> 4 + Node::Mul(3, 1), // 6 -> 5 + Node::Mul(1, 1), // 7 -> 6 + Node::Constant(FqBLS::from(4)), // 8 -> 7 + Node::Mul(6, 6), // 9 -> 8 + Node::Add(7, 4), // 11 -> 9 + Node::Add(7, 1), // 12 -> 10 + Node::Mul(4, 4), // 13 -> 11 + Node::Constant(-FqBLS::from(2)), // 15 -> 12 + Node::Variable("z".to_string()), // 16 -> 13 + Node::Add(10, 4), // 18 -> 14 + ]; + + assert_eq!(filter_constants(&nodes).0, filtered_nodes); +} diff --git a/relations/src/expression/mod.rs b/relations/src/expression/mod.rs new file mode 100644 index 000000000..c3f50e9f7 --- /dev/null +++ b/relations/src/expression/mod.rs @@ -0,0 +1,317 @@ +//! This module contains the core functionality for arithmetic expressions. +use crate::arithmetic_circuit::{filter_constants, ArithmeticCircuit, Node}; +use ark_ff::PrimeField; +use ark_std::{ + clone::Clone, + collections::BTreeMap, + convert::{AsRef, From}, + fmt::{Display, Formatter, Result}, + iter::{FromIterator, IntoIterator, Iterator, Product, Sum}, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, + option::Option::{None, Some}, + rc::Rc, + string::{String, ToString}, + vec::Vec, +}; +use itertools::Itertools; + +#[cfg(test)] +mod tests; + +/// Utilities that expose a user-friendly way to construct arithmetic circuits, +/// with syntax along the lines of: +/// let x = Expression::Variable("x"); +/// let y = Expression::Variable("y"); +/// let output = y.pow(2) - x.pow(3) + 1 + +enum ExpressionInner { + Variable(String), + Constant(F), + Add(Expression, Expression), + Mul(Expression, Expression), +} + +/// Represents an arithmetic expression over a field F. An expression is a +/// combination of products and sums of constants and variables. It is +/// represented as a pointer to an `ExpressionInner`, which is a +/// reference-counted enum that can be one of the following: +/// - Variable(String): a variable with a given label. +/// - Constant(F): a constant value in the field F. +/// - Add(Expression, Expression): the sum of two expressions. +/// - Mul(Expression, Expression): the product of two expressions. +/// +/// Expressions expose a user-friendly way to construct arithmetic circuits, +/// with syntax along the lines of: +/// let x = Expression::Variable("x"); +/// let y = Expression::Variable("y"); +/// let output = y.pow(2) - x.pow(3) + 1 +/// +/// Syntax summary: +/// - Expression::variable(id) creates a variable with the given ID. +/// +/// - Expression::constant(value) creates a constant with the given F value. +/// +/// - +, - and * are overloaded to mean addition, subtraction a nd +/// multiplication of expressions Their assigning counterparts +=, -=, *= are +/// also overloaded. +/// +/// - Constants in the form of F can be used as operands on the right-hand side +/// only. This is due to the implementation for i32 from the next point. E.g.: +/// F::from(3) * exp, F::ONE * exp, and exp * F::from(3) are all valid +/// However, 3 * exp, -5 * exp, and exp * 3 are not. +/// +/// - Constants in the form of i32 (where F: From) can be used as operands +/// on the left-hand side only. This is due to i32 and PrimeField both being +/// foreign types. E.g. 1 + exp and -5 * exp are both valid, equivalent to +/// F::from(1) + exp and F::from(-5) * exp, respectively. However, exp + 1, +/// exp - 3 and exp * -5 are not. +pub struct Expression(Rc>); + +impl Expression { + /// Creates an expression representing a constant value. + pub fn constant(value: F) -> Self { + Expression(Rc::new(ExpressionInner::Constant(value))) + } + + /// Creates an expression representing a variable with a given label. + /// The label must be unique. + pub fn variable(label: &str) -> Self { + Expression(Rc::new(ExpressionInner::Variable(label.to_string()))) + } + + /// Converts the expression into an `ArithmeticCircuit`. + pub fn to_arithmetic_circuit(&self) -> ArithmeticCircuit { + let mut nodes = BTreeMap::new(); + self.update_map(&mut nodes); + + let ptr_to_idx = nodes + .iter() + .map(|(ptr, (idx, _))| (*ptr, nodes.len() - idx - 1)) + .collect::>(); + + let sorted_nodes = nodes + .into_iter() + .sorted_by(|(_, (i, _)), (_, (j, _))| j.cmp(i)) + .map(|(_, (_, node))| node) + .collect::>(); + + let mut nodes = Vec::new(); + for node in sorted_nodes { + match node { + Node::Variable(label) => { + nodes.push(Node::Variable(label)); + }, + Node::Constant(value) => { + nodes.push(Node::Constant(value)); + }, + Node::Add(a, b) => { + nodes.push(Node::Add(ptr_to_idx[&a], ptr_to_idx[&b])); + }, + Node::Mul(a, b) => { + nodes.push(Node::Mul(ptr_to_idx[&a], ptr_to_idx[&b])); + }, + } + } + + let (nodes, constants) = filter_constants(&nodes); + + let variables = BTreeMap::from_iter(nodes.iter().enumerate().filter_map(|(i, node)| { + if let Node::Variable(label) = node { + Some((label.clone(), i)) + } else { + None + } + })); + + ArithmeticCircuit { + nodes, + constants, + variables, + unit_group_bits: None, + } + } + + fn pointer(&self) -> usize { + self.0.as_ref() as *const _ as usize + } + + fn update_map(&self, nodes: &mut BTreeMap)>) { + if nodes.contains_key(&self.pointer()) { + return; + } + match &*self.0 { + ExpressionInner::Variable(label) => { + nodes.insert(self.pointer(), (nodes.len(), Node::Variable(label.clone()))); + }, + ExpressionInner::Constant(value) => { + nodes.insert(self.pointer(), (nodes.len(), Node::Constant(*value))); + }, + ExpressionInner::Add(a, b) => { + nodes.insert( + self.pointer(), + (nodes.len(), Node::Add(a.pointer(), b.pointer())), + ); + a.update_map(nodes); + b.update_map(nodes); + }, + ExpressionInner::Mul(a, b) => { + nodes.insert( + self.pointer(), + (nodes.len(), Node::Mul(a.pointer(), b.pointer())), + ); + a.update_map(nodes); + b.update_map(nodes); + }, + } + } + + /// Computes the scalar product of two vectors of expressions. + pub fn scalar_product(a: Vec>, b: Vec>) -> Expression { + a.into_iter().zip(b).map(|(a, b)| a * b).sum() + } + + /// Computes the scalar product of a sparse vector of field elements and a + /// vector of expressions. + pub fn sparse_scalar_product(a: &Vec<(F, usize)>, b: &Vec>) -> Expression { + a.iter() + .map(|(a, i)| b[*i].clone() * *a) + .collect::>() + .into_iter() + .sum() + } + + /// Computes the power of an expression using a square-and-multiply + /// strategy. + pub fn pow(self, rhs: usize) -> Self { + if rhs == 0 { + return self; + } + + let mut bits = (0..usize::BITS).rev().map(|pos| (rhs >> pos) & 1); + + bits.position(|bit| bit == 1); + + let mut current = self.clone(); + + for bit in bits { + current = current.clone() * current; + + if bit == 1 { + current = current.clone() * self.clone(); + } + } + + current + } +} + +impl Clone for Expression { + fn clone(&self) -> Self { + Expression(Rc::clone(&self.0)) + } +} + +impl Neg for Expression { + type Output = Expression; + + fn neg(self) -> Self::Output { + Expression::constant(-F::ONE) * self + } +} + +impl Add for Expression { + type Output = Expression; + + fn add(self, rhs: Expression) -> Self::Output { + Expression(Rc::new(ExpressionInner::Add(self.clone(), rhs.clone()))) + } +} + +impl Mul for Expression { + type Output = Expression; + + fn mul(self, rhs: Self) -> Self::Output { + Expression(Rc::new(ExpressionInner::Mul(self.clone(), rhs.clone()))) + } +} + +impl Sub for Expression { + type Output = Expression; + + fn sub(self, rhs: Self) -> Self::Output { + self + (-rhs) + } +} + +macro_rules! impl_constant_op { + ($op_trait:ident, $method:ident, $op:tt) => { + impl> $op_trait> for i32 { + type Output = Expression; + + fn $method(self, rhs: Expression) -> Self::Output { + Expression::constant(F::from(self)) $op rhs + } + } + + impl $op_trait for Expression { + type Output = Expression; + + fn $method(self, rhs: F) -> Self::Output { + self $op Expression::constant(rhs) + } + } + }; +} + +impl_constant_op!(Add, add, +); +impl_constant_op!(Mul, mul, *); +impl_constant_op!(Sub, sub, -); + +macro_rules! impl_op_assign_aux { + ($op_trait_assign:ident, $method_assign:ident, $op:tt, $self_type:ty) => { + impl $op_trait_assign<$self_type> for Expression { + fn $method_assign(&mut self, rhs: $self_type) { + *self = self.clone() $op rhs; + } + } + }; +} + +macro_rules! impl_op_assign { + ($op_trait_assign:ident, $method_assign:ident, $op:tt) => { + impl_op_assign_aux!($op_trait_assign, $method_assign, $op, F); + impl_op_assign_aux!($op_trait_assign, $method_assign, $op, Expression); + }; +} + +impl_op_assign!(AddAssign, add_assign, +); +impl_op_assign!(MulAssign, mul_assign, *); +impl_op_assign!(SubAssign, sub_assign, -); + +impl Sum for Expression { + fn sum>(iter: I) -> Self { + iter.reduce(|a, b| a + b).unwrap() + } +} + +impl Product for Expression { + fn product>(iter: I) -> Self { + iter.reduce(|a, b| a * b).unwrap() + } +} + +impl Display for Expression { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let hash = self.pointer(); + match &*self.0 { + ExpressionInner::Variable(label) => write!(f, "Variable({})<{}>", label, hash), + ExpressionInner::Constant(value) => write!(f, "Constant({:?})<{}>", value, hash), + ExpressionInner::Add(a, b) => { + write!(f, "Add({}, {})<{}>", a.pointer(), b.pointer(), hash) + }, + ExpressionInner::Mul(a, b) => { + write!(f, "Mul({}, {})<{}>", a.pointer(), b.pointer(), hash) + }, + } + } +} diff --git a/relations/src/expression/tests.rs b/relations/src/expression/tests.rs new file mode 100644 index 000000000..2b075c573 --- /dev/null +++ b/relations/src/expression/tests.rs @@ -0,0 +1,387 @@ +use ark_bls12_377::{Fq, G1Affine}; +use ark_bn254::Fr; +use ark_ec::short_weierstrass::Affine; +use ark_ff::{Field, UniformRand}; +use ark_std::{collections::BTreeMap, ops::Deref, string::ToString, test_rng, vec::Vec}; +use itertools::Itertools; + +use crate::arithmetic_circuit::Node; + +use super::{Expression, ExpressionInner}; + +/// Generates the expression for the BLS12-377 elliptic curve. +/// +/// The curve is defined by the equation: +/// 1 + (1 + x^3 - y^2) = 1 +fn generate_bls12_377_expression() -> Expression { + let x = Expression::variable("x"); + let y = Expression::variable("y"); + + 1 + (1 + x.pow(3) - y.pow(2)) +} + +/// Generates the expression for the lemniscate curve defined by: +/// (x^2 + y^2)^2 - 120x^2 + 80y^2 + 1 = 1 +fn generate_lemniscate_expression() -> Expression { + let x = Expression::variable("x"); + let y = Expression::variable("y"); + + 1 + (x.clone().pow(2) + y.clone().pow(2)).pow(2) - 120 * x.pow(2) + 80 * y.pow(2) +} + +/// Generates the expression for the determinant of a 3x3 matrix. +fn generate_3_by_3_determinant_expression() -> Expression { + let matrix = (0..3) + .map(|i| { + (0..3) + .map(|j| Expression::variable(&format!("x_{i}_{j}"))) + .collect::>() + }) + .collect::>(); + + let possitive_diagonal = (0..3) + .map(|k| { + vec![0, 4, 8] + .into_iter() + .zip(0..3) + .map(|(j, i)| matrix[i][(j + k) % 3].clone()) + .product() + }) + .sum::>(); + + let negative_diagonal = (0..3) + .map(|k| { + vec![2, 4, 6] + .into_iter() + .zip(0..3) + .map(|(j, i)| matrix[i][(j + k) % 3].clone()) + .product() + }) + .sum::>(); + + let det = Expression::variable("det"); + + 1 + (possitive_diagonal - negative_diagonal - det) +} + +#[test] +fn test_get_variables() { + let circuit = generate_bls12_377_expression().to_arithmetic_circuit(); + + assert_eq!(circuit.get_variable("x"), 4); + assert_eq!(circuit.get_variable("y"), 0); + + let circuit = generate_lemniscate_expression().to_arithmetic_circuit(); + + assert_eq!(circuit.get_variable("x"), 10); + assert_eq!(circuit.get_variable("y"), 8); +} + +#[test] +fn test_same_reference() { + let mut f1 = Expression::::variable("x"); + let mut f2 = Expression::::variable("y"); + + let original_f1 = f1.clone(); + + for _ in 0..10 { + let next = f1.clone() + f2.clone(); + f1 = f2; + f2 = next; + } + + let curious_result = f2 * original_f1.clone(); + + match curious_result.0.deref() { + ExpressionInner::Mul(_, right) => { + assert_eq!(right.pointer(), original_f1.pointer()) + }, + _ => panic!("Expected a multiplication expression"), + } +} + +#[test] +fn test_addition() { + let a = Expression::variable("x"); + let b = Expression::variable("y"); + + let expression = a + b; + let circuit = expression.to_arithmetic_circuit(); + + assert_eq!( + circuit.evaluate_with_labels(vec![("x", Fr::from(3)), ("y", Fr::from(5))]), + Fr::from(8), + ); +} + +#[test] +fn test_multiplication() { + let a = Expression::variable("x"); + let b = Expression::variable("y"); + + let expression = a * b; + let circuit = expression.to_arithmetic_circuit(); + + assert_eq!( + circuit.evaluate_with_labels(vec![("x", Fr::from(3)), ("y", Fr::from(5))]), + Fr::from(15), + ); +} + +#[test] +fn test_subtraction() { + let a = Expression::variable("x"); + let b = Expression::variable("y"); + + let expression = a - b; + let circuit = expression.to_arithmetic_circuit(); + + assert_eq!( + circuit.evaluate_with_labels(vec![("x", Fr::from(3)), ("y", Fr::from(5))]), + Fr::from(-2), + ); +} + +#[test] +fn test_some_operations() { + let x_f = Fr::from(5); + let y_f = Fr::from(3); + + let output = x_f.pow([3]) + (y_f - Fr::ONE).pow([11]) + Fr::from(13); + + let x_exp = Expression::constant(x_f); + let y_exp = Expression::constant(y_f); + + let output_exp = 13 + x_exp.pow(3) + (y_exp - Fr::ONE).pow(11); + let circ_output = output_exp.to_arithmetic_circuit().evaluate(vec![]); + + assert_eq!(output, circ_output); +} + +// Add +// / \ +// Add Mul +// / \ / +// \ Constant(3) Mul Add +// Add / \ / \ +// / \ Constant(2) Mul Constant(3) +// Mul Constant(1) Mul / \ +// / \ / \ | | Constant(2) | Constant(2) Variable(1) +// | | | | +// ------------------------------------------------- +// | | | | +// ------------------------------------------------------------------------------------- +// | | +// Variable(0) Variable(1) +// Original: +// +// 16: Add(15, 5) +// 15: Add(14, 13) +// 14: Constant(3) +// 13: Mul(12, 11) +// 12: Constant(2) +// 11: Mul(10, 9) +// 10: Variable(0) +// 9: Variable(1) +// 8: Mul(4, 10) +// 7: Add(6, 8) +// 6: Constant(3) +// 5: Mul(3, 7) +// 4: Constant(2) +// 3: Add(2, 1) +// 2: Constant(1) +// 1: Mul(0, 9) +// 0: Constant(2) +// +// After filtering: +// +// 13: Add(12, 7) -> (3 + 2 * x * y) + (3 + 2 * x) * (1 + 2 * y) = 60 +// 12: Add(5, 11) -> 3 + 2 * x * y = 15 +// 11: Mul(0, 10) -> 2 * x * y = 12 +// 10: Mul(9, 8) -> x * y = 6 +// 9: Variable(0) -> x = 3 +// 8: Variable(1) -> y = 2 +// 7: Mul(6, 3) -> (3 + 2 * x) * (1 + 2 * y) = 45 +// 6: Add(5, 4) -> 3 + 2 * x = 9 +// 5: Constant(3) -> 3 +// 4: Mul(0, 9) -> 2 * x = 6 +// 3: Add(2, 1) -> 1 + 2 * y = 5 +// 2: Constant(1) -> 1 +// 1: Mul(0, 8) -> 2 * y = 4 +// 0: Constant(2) -> 2 +// +#[test] +fn test_to_arithmetic_circuit_1() { + let x = Expression::variable("x"); + let y = Expression::variable("y"); + + let expression = (3 + 2 * (x.clone() * y.clone())) + ((3 + 2 * x) * (1 + 2 * y)); + + let circuit = expression.to_arithmetic_circuit(); + + assert_eq!( + circuit.nodes, + vec![ + Node::Add(12, 7), + Node::Add(5, 11), + Node::Mul(0, 10), + Node::Mul(9, 8), + Node::Variable("x".to_string()), + Node::Variable("y".to_string()), + Node::Mul(6, 3), + Node::Add(5, 4), + Node::Constant(Fr::from(3)), + Node::Mul(0, 9), + Node::Add(2, 1), + Node::Constant(Fr::ONE), + Node::Mul(0, 8), + Node::Constant(Fr::from(2)), + ] + .iter() + .rev() + .cloned() + .collect::>() + ); + + assert_eq!( + circuit.constants, + [(Fr::from(3), 5), (Fr::ONE, 2), (Fr::from(2), 0)] + .iter() + .cloned() + .collect::>() + ); + + assert_eq!( + circuit.evaluation_trace_with_labels(vec![("x", Fr::from(3)), ("y", Fr::from(2))], 13), + vec![ + Some(Fr::from(60)), + Some(Fr::from(15)), + Some(Fr::from(12)), + Some(Fr::from(6)), + Some(Fr::from(3)), + Some(Fr::from(2)), + Some(Fr::from(45)), + Some(Fr::from(9)), + Some(Fr::from(3)), + Some(Fr::from(6)), + Some(Fr::from(5)), + Some(Fr::from(1)), + Some(Fr::from(4)), + Some(Fr::from(2)), + ] + .iter() + .rev() + .cloned() + .collect::>() + ); +} + +// Mul() +// / \ +// Add() Add() +// / \ / \ +// | | Variable(c) Mul() +// | | / \ +// ------------------------------- | +// | | | +// | ------------------------------ +// | | +// Variable(a) Variable(b) +// +// 6: Mul(5, 2) -> 35 +// 5: Add(4, 3) -> 5 +// 4: Variable(a) -> 3 +// 3: Variable(b) -> 2 +// 2: Add(1, 0) -> 7 +// 1: Variable(c) -> 1 +// 0: Mul(4, 3) -> 6 +// + +#[test] +fn test_to_arithmetic_circuit_2() { + let a = Expression::variable("a"); + let b = Expression::variable("b"); + let c = Expression::variable("c"); + + let expression = (a.clone() + b.clone()) * (c + a * b); + + let circuit = expression.to_arithmetic_circuit(); + + assert_eq!( + circuit.nodes, + vec![ + Node::Mul(5, 2), + Node::Add(4, 3), + Node::Variable("a".to_string()), + Node::Variable("b".to_string()), + Node::Add(1, 0), + Node::Variable("c".to_string()), + Node::Mul(4, 3), + ] + .iter() + .rev() + .cloned() + .collect::>() + ); + + assert_eq!(circuit.constants, BTreeMap::new()); + + assert_eq!( + circuit.evaluation_trace_with_labels( + vec![("a", Fr::from(3)), ("b", Fr::from(2)), ("c", Fr::from(1))], + 6 + ), + vec![ + Some(Fr::from(6)), + Some(Fr::from(1)), + Some(Fr::from(7)), + Some(Fr::from(2)), + Some(Fr::from(3)), + Some(Fr::from(5)), + Some(Fr::from(35)) + ] + ); +} + +#[test] +fn test_to_arithmetic_circuit_3() { + let circuit = generate_3_by_3_determinant_expression().to_arithmetic_circuit(); + + let values = (0..3) + .cartesian_product(0..3) + .map(|(i, j)| { + ( + format!("x_{}_{}", i, j), + Fr::from((3 * i + j) * (3 * i + j)), + ) + }) + .collect::>(); + + let vars = [values, vec![("det".to_string(), Fr::from(-216))]].concat(); + + assert_eq!( + circuit.evaluate_with_labels(vars.iter().map(|(k, v)| (k.as_str(), *v)).collect()), + Fr::ONE, + ); +} + +#[test] +fn test_to_arithmetic_circuit_4() { + let circuit = generate_bls12_377_expression().to_arithmetic_circuit(); + let Affine { x, y, .. } = G1Affine::rand(&mut test_rng()); + + assert_eq!( + circuit.evaluate_with_labels(vec![("x", x), ("y", y)]), + Fq::ONE + ); +} + +#[test] +fn test_to_arithmetic_circuit_5() { + let circuit = generate_lemniscate_expression().to_arithmetic_circuit(); + + assert_eq!( + circuit.evaluate_with_labels(vec![("x", Fr::from(8)), ("y", Fr::from(4))]), + Fr::ONE + ); +} diff --git a/relations/src/lib.rs b/relations/src/lib.rs index cf249f777..40a7125d7 100644 --- a/relations/src/lib.rs +++ b/relations/src/lib.rs @@ -1,6 +1,6 @@ //! Core interface for working with various relations that are useful in -//! zkSNARKs. At the moment, we only implement APIs for working with Rank-1 -//! Constraint Systems (R1CS). +//! zkSNARKs. At the moment, we implement APIs for working with Rank-1 +//! Constraint Systems (R1CS), Arithmetic Circuits and Arithmetic Expressions. #![cfg_attr(not(feature = "std"), no_std)] #![warn( @@ -15,4 +15,17 @@ #[macro_use] extern crate ark_std; +pub mod arithmetic_circuit; +pub mod expression; pub mod r1cs; + +#[cfg(test)] +pub(crate) mod reader; + +/// The path to the test data directory that contains the circom test files. +#[macro_export] +macro_rules! TEST_DATA_PATH { + () => { + concat!(env!("CARGO_MANIFEST_DIR"), "/../circom/{}",) + }; +} diff --git a/relations/src/r1cs/constraint_system.rs b/relations/src/r1cs/constraint_system.rs index 1b0d54922..cb92564a3 100644 --- a/relations/src/r1cs/constraint_system.rs +++ b/relations/src/r1cs/constraint_system.rs @@ -176,7 +176,8 @@ impl ConstraintSystem { /// Specify whether this constraint system should aim to optimize weight, /// number of constraints, or neither. pub fn set_optimization_goal(&mut self, goal: OptimizationGoal) { - // `set_optimization_goal` should only be executed before any constraint or value is created. + // `set_optimization_goal` should only be executed before any constraint or + // value is created. assert_eq!(self.num_instance_variables, 1); assert_eq!(self.num_witness_variables, 0); assert_eq!(self.num_constraints, 0); @@ -297,15 +298,17 @@ impl ConstraintSystem { /// Transform the map of linear combinations. /// Specifically, allow the creation of additional witness assignments. /// - /// This method is used as a subroutine of `inline_all_lcs` and `outline_lcs`. + /// This method is used as a subroutine of `inline_all_lcs` and + /// `outline_lcs`. /// - /// The transformer function is given a references of this constraint system (&self), - /// number of times used, and a mutable reference of the linear combination to be transformed. - /// (&ConstraintSystem, usize, &mut LinearCombination) + /// The transformer function is given a references of this constraint system + /// (&self), number of times used, and a mutable reference of the linear + /// combination to be transformed. (&ConstraintSystem, usize, + /// &mut LinearCombination) /// - /// The transformer function returns the number of new witness variables needed - /// and a vector of new witness assignments (if not in the setup mode). - /// (usize, Option>) + /// The transformer function returns the number of new witness variables + /// needed and a vector of new witness assignments (if not in the setup + /// mode). (usize, Option>) pub fn transform_lc_map( &mut self, transformer: &mut dyn FnMut( @@ -1103,7 +1106,8 @@ mod tests { // There will be six variables in the system, in the order governed by adding // them to the constraint system (Note that the CS is initialised with // `Variable::One` in the first position implicitly). - // Note also that the all public variables will always be placed before all witnesses + // Note also that the all public variables will always be placed before all + // witnesses // // Variable::One // Variable::Instance(35) @@ -1134,7 +1138,8 @@ mod tests { // There are four gates(constraints), each generating a row. // Resulting matrices: // (Note how 2nd & 3rd columns are swapped compared to the online example. - // This results from an implementation detail of placing all Variable::Instances(_) first. + // This results from an implementation detail of placing all + // Variable::Instances(_) first. // // A // [0, 0, 1, 0, 0, 0] diff --git a/relations/src/reader.rs b/relations/src/reader.rs new file mode 100644 index 000000000..ec9bd261e --- /dev/null +++ b/relations/src/reader.rs @@ -0,0 +1,89 @@ +use crate::r1cs::{ + ConstraintSystem, ConstraintSystemRef, LinearCombination, SynthesisError, Variable, +}; +use ark_circom::{CircomBuilder, CircomCircuit, CircomConfig}; +use ark_ff::PrimeField; +use ark_std::path::Path; + +pub fn read_constraint_system( + r1cs_file: impl AsRef, + wasm_file: impl AsRef, +) -> ConstraintSystem { + // Load the WASM and R1CS for witness and proof generation + let cfg = CircomConfig::::new(wasm_file, r1cs_file).unwrap(); + + let builder = CircomBuilder::new(cfg); + let circom = builder.setup(); + + let cs = ConstraintSystem::::new_ref(); + + // TODO: replace with `circom.generate_constraints(cs.clone())` once the + // dependency issue is fixed. + generate_constraints(circom, cs.clone()).unwrap(); + + cs.into_inner().unwrap() +} + +// TODO: currently, CircomCircuit::generate_constraints() cannot be used due to +// inconsistent imports. This branch of `ark-relations` has +// `circom-compat/release-0.5` as a dev-dependency, and +// `circom-compat/release-0.5` has `ark-relations v0.5.0` as a dependency, which +// currently does not exist, therefore +// `ark-circom::ark_relations::r1cs::ConstraintSystem` is not the same structure +// as `crate::r1cs::ConstraintSystem`. +fn generate_constraints( + circom: CircomCircuit, + cs: ConstraintSystemRef, +) -> Result<(), SynthesisError> { + let witness = &circom.witness; + let wire_mapping = &circom.r1cs.wire_mapping; + + // Start from 1 because Arkworks implicitly allocates One for the first input + for i in 1..circom.r1cs.num_inputs { + cs.new_input_variable(|| { + Ok(match witness { + None => F::from(1u32), + Some(w) => match wire_mapping { + Some(m) => w[m[i]], + None => w[i], + }, + }) + })?; + } + + for i in 0..circom.r1cs.num_aux { + cs.new_witness_variable(|| { + Ok(match witness { + None => F::from(1u32), + Some(w) => match wire_mapping { + Some(m) => w[m[i + circom.r1cs.num_inputs]], + None => w[i + circom.r1cs.num_inputs], + }, + }) + })?; + } + + let make_index = |index| { + if index < circom.r1cs.num_inputs { + Variable::Instance(index) + } else { + Variable::Witness(index - circom.r1cs.num_inputs) + } + }; + let make_lc = |lc_data: &[(usize, F)]| { + lc_data.iter().fold( + LinearCombination::::zero(), + |lc: LinearCombination, (index, coeff)| lc + (*coeff, make_index(*index)), + ) + }; + + for constraint in &circom.r1cs.constraints { + cs.enforce_constraint( + make_lc(&constraint.0), + make_lc(&constraint.1), + make_lc(&constraint.2), + )?; + } + + Ok(()) +} diff --git a/snark/Cargo.toml b/snark/Cargo.toml index b14a88ee6..30dbf74dd 100644 --- a/snark/Cargo.toml +++ b/snark/Cargo.toml @@ -17,3 +17,7 @@ ark-ff.workspace = true ark-std.workspace = true ark-serialize.workspace = true ark-relations = { version = "0.5", path = "../relations", default-features = false } + +[features] +default = [] +std = [ "ark-std/std", "ark-ff/std", "ark-serialize/std" ] \ No newline at end of file diff --git a/snark/src/lib.rs b/snark/src/lib.rs index 6003837ed..e989cba7d 100644 --- a/snark/src/lib.rs +++ b/snark/src/lib.rs @@ -13,8 +13,10 @@ use ark_ff::PrimeField; use ark_relations::r1cs::ConstraintSynthesizer; use ark_serialize::{CanonicalDeserialize, CanonicalSerialize}; -use ark_std::fmt::Debug; -use ark_std::rand::{CryptoRng, RngCore}; +use ark_std::{ + fmt::Debug, + rand::{CryptoRng, RngCore}, +}; /// The basic functionality for a SNARK. pub trait SNARK { @@ -52,8 +54,8 @@ pub trait SNARK { ) -> Result; /// Checks that `proof` is a valid proof of the satisfaction of circuit - /// encoded in `circuit_vk`, with respect to the public input `public_input`, - /// specified as R1CS constraints. + /// encoded in `circuit_vk`, with respect to the public input + /// `public_input`, specified as R1CS constraints. fn verify( circuit_vk: &Self::VerifyingKey, public_input: &[F], @@ -69,8 +71,8 @@ pub trait SNARK { ) -> Result; /// Checks that `proof` is a valid proof of the satisfaction of circuit - /// encoded in `circuit_pvk`, with respect to the public input `public_input`, - /// specified as R1CS constraints. + /// encoded in `circuit_pvk`, with respect to the public input + /// `public_input`, specified as R1CS constraints. fn verify_with_processed_vk( circuit_pvk: &Self::ProcessedVerifyingKey, public_input: &[F],