1- use std:: { array, borrow:: BorrowMut , cmp:: max, sync:: Arc } ;
1+ use std:: { borrow:: BorrowMut , cmp:: max, sync:: Arc } ;
2+ use crate :: { ShaDigestColsRefMut , ShaRoundColsRefMut , ShaRoundColsRef } ;
23
34use openvm_circuit:: arch:: {
45 instructions:: riscv:: RV32_CELL_BITS ,
@@ -22,8 +23,7 @@ use openvm_stark_sdk::utils::create_seeded_rng;
2223use rand:: Rng ;
2324
2425use crate :: {
25- compose, small_sig0_field, Sha256Config , Sha512Config , ShaAir , ShaConfig , ShaFlagsColsRef ,
26- ShaFlagsColsRefMut , ShaPrecomputedValues ,
26+ compose, small_sig0_field, Sha256Config , Sha512Config , ShaAir , ShaConfig , ShaPrecomputedValues ,
2727} ;
2828
2929// A wrapper AIR purely for testing purposes
@@ -101,7 +101,7 @@ fn rand_sha_test<C: ShaConfig + ShaPrecomputedValues<C::Word> + 'static>() {
101101 let bitwise_chip = SharedBitwiseOperationLookupChip :: < RV32_CELL_BITS > :: new ( bitwise_bus) ;
102102 let len = rng. gen_range ( 1 ..100 ) ;
103103 let random_records: Vec < _ > = ( 0 ..len)
104- . map ( |_ | {
104+ . map ( |i | {
105105 (
106106 ( 0 ..C :: BLOCK_U8S )
107107 . map ( |_| rng. gen :: < u8 > ( ) )
@@ -137,7 +137,7 @@ fn rand_sha512_test() {
137137pub struct ShaTestBadFinalHashChip < C : ShaConfig + ShaPrecomputedValues < C :: Word > > {
138138 pub air : ShaTestAir < C > ,
139139 pub bitwise_lookup_chip : SharedBitwiseOperationLookupChip < 8 > ,
140- pub records : Vec < ( Vec < u8 > , bool ) > , // length of inner vec is C::BLOCK_U8S
140+ pub records : Vec < ( Vec < u8 > , bool ) > , // length of inner vec should be C::BLOCK_U8S
141141}
142142
143143impl < SC : StarkGenericConfig , C : ShaConfig + ShaPrecomputedValues < C :: Word > + ' static > Chip < SC >
@@ -150,7 +150,7 @@ where
150150 }
151151
152152 fn generate_air_proof_input ( self ) -> AirProofInput < SC > {
153- let mut trace = crate :: generate_trace :: < Val < SC > > (
153+ let mut trace = crate :: generate_trace :: < Val < SC > , C > (
154154 & self . air . sub_air ,
155155 self . bitwise_lookup_chip . clone ( ) ,
156156 self . records . clone ( ) ,
@@ -161,7 +161,7 @@ where
161161 for ( i, row) in self . records . iter ( ) . enumerate ( ) {
162162 if row. 1 {
163163 let last_digest_row_idx = ( i + 1 ) * C :: ROWS_PER_BLOCK - 1 ;
164- let last_digest_row: crate :: ShaDigestColsRefMut < Val < SC > > =
164+ let mut last_digest_row: crate :: ShaDigestColsRefMut < Val < SC > > =
165165 ShaDigestColsRefMut :: from :: < C > (
166166 trace. row_mut ( last_digest_row_idx) [ ..C :: DIGEST_WIDTH ] . borrow_mut ( ) ,
167167 ) ;
@@ -176,10 +176,10 @@ where
176176 trace. row_pair_mut ( last_digest_row_idx - 1 , last_digest_row_idx) ;
177177 let last_round_row: crate :: ShaRoundColsRefMut < Val < SC > > =
178178 ShaRoundColsRefMut :: from :: < C > ( last_round_row. borrow_mut ( ) ) ;
179- let last_digest_row: crate :: ShaRoundColsRefMut < Val < SC > > =
179+ let mut last_digest_row: crate :: ShaRoundColsRefMut < Val < SC > > =
180180 ShaRoundColsRefMut :: from :: < C > ( last_digest_row. borrow_mut ( ) ) ;
181181 // fix the intermed_4 for the digest row
182- generate_intermed_4 ( last_round_row, last_digest_row) ;
182+ generate_intermed_4 :: < Val < SC > , C > ( & ShaRoundColsRef :: from_mut :: < C > ( & last_round_row) , & mut last_digest_row) ;
183183 }
184184 }
185185
@@ -199,52 +199,72 @@ where
199199
200200// Copy of private method in Sha256Air used for testing
201201/// Puts the correct intermed_4 in the `next_row`
202- fn generate_intermed_4 < F : PrimeField32 > (
203- local_cols : & Sha256RoundCols < F > ,
204- next_cols : & mut Sha256RoundCols < F > ,
202+ fn generate_intermed_4 < F : PrimeField32 , C : ShaConfig + ShaPrecomputedValues < C :: Word > > (
203+ local_cols : & ShaRoundColsRef < F > ,
204+ next_cols : & mut ShaRoundColsRefMut < F > ,
205205) {
206- let w = [ local_cols. message_schedule . w , next_cols. message_schedule . w ] . concat ( ) ;
207- let w_limbs: Vec < [ F ; SHA256_WORD_U16S ] > = w
206+ let w = [
207+ local_cols
208+ . message_schedule
209+ . w
210+ . rows ( )
211+ . into_iter ( )
212+ . collect :: < Vec < _ > > ( ) ,
213+ next_cols
214+ . message_schedule
215+ . w
216+ . rows ( )
217+ . into_iter ( )
218+ . collect :: < Vec < _ > > ( ) ,
219+ ]
220+ . concat ( ) ;
221+
222+
223+ // length of inner vec is C::WORD_U16S
224+ let w_limbs: Vec < Vec < F > > = w
208225 . iter ( )
209- . map ( |x| array:: from_fn ( |i| compose :: < F > ( & x[ i * 16 ..( i + 1 ) * 16 ] , 1 ) ) )
226+ . map ( |x| {
227+ ( 0 ..C :: WORD_U16S )
228+ . map ( |i| compose :: < F > ( & x. as_slice ( ) . unwrap ( ) [ i * 16 ..( i + 1 ) * 16 ] , 1 ) )
229+ . collect :: < Vec < F > > ( )
230+ } )
210231 . collect ( ) ;
211- for i in 0 ..SHA256_ROUNDS_PER_ROW {
212- let sig_w = small_sig0_field :: < F > ( & w[ i + 1 ] ) ;
213- let sig_w_limbs: [ F ; SHA256_WORD_U16S ] =
214- array:: from_fn ( |j| compose :: < F > ( & sig_w[ j * 16 ..( j + 1 ) * 16 ] , 1 ) ) ;
232+ for i in 0 ..C :: ROUNDS_PER_ROW {
233+ let sig_w = small_sig0_field :: < F , C > ( w[ i + 1 ] . as_slice ( ) . unwrap ( ) ) ;
234+ let sig_w_limbs: Vec < F > = ( 0 ..C :: WORD_U16S )
235+ . map ( |j| compose :: < F > ( & sig_w[ j * 16 ..( j + 1 ) * 16 ] , 1 ) )
236+ . collect ( ) ;
215237 for ( j, sig_w_limb) in sig_w_limbs. iter ( ) . enumerate ( ) {
216- next_cols. schedule_helper . intermed_4 [ i ] [ j ] = w_limbs[ i] [ j] + * sig_w_limb;
238+ next_cols. schedule_helper . intermed_4 [ [ i , j ] ] = w_limbs[ i] [ j] + * sig_w_limb;
217239 }
218240 }
219241}
220242
221- impl ChipUsageGetter for Sha256TestBadFinalHashChip {
243+ impl < C : ShaConfig + ShaPrecomputedValues < C :: Word > > ChipUsageGetter for ShaTestBadFinalHashChip < C > {
222244 fn air_name ( & self ) -> String {
223245 get_air_name ( & self . air )
224246 }
225247 fn current_trace_height ( & self ) -> usize {
226- self . records . len ( ) * SHA256_ROWS_PER_BLOCK
248+ self . records . len ( ) * C :: ROWS_PER_BLOCK
227249 }
228250
229251 fn trace_width ( & self ) -> usize {
230- max ( SHA256_ROUND_WIDTH , SHA256_DIGEST_WIDTH )
252+ max ( C :: ROUND_WIDTH , C :: DIGEST_WIDTH )
231253 }
232254}
233255
234- #[ test]
235- #[ should_panic]
236- fn test_sha256_final_hash_constraints ( ) {
256+ fn test_sha_final_hash_constraints < C : ShaConfig + ShaPrecomputedValues < C :: Word > + ' static > ( ) {
237257 let mut rng = create_seeded_rng ( ) ;
238258 let tester = VmChipTestBuilder :: default ( ) ;
239259 let bitwise_bus = BitwiseOperationLookupBus :: new ( BITWISE_OP_LOOKUP_BUS ) ;
240260 let bitwise_chip = SharedBitwiseOperationLookupChip :: < RV32_CELL_BITS > :: new ( bitwise_bus) ;
241261 let len = rng. gen_range ( 1 ..100 ) ;
242262 let random_records: Vec < _ > = ( 0 ..len)
243- . map ( |_| ( array :: from_fn ( |_| rng. gen :: < u8 > ( ) ) , true ) )
263+ . map ( |_| ( ( 0 .. C :: BLOCK_U8S ) . map ( |_| rng. gen :: < u8 > ( ) ) . collect :: < Vec < _ > > ( ) , true ) )
244264 . collect ( ) ;
245- let chip = Sha256TestBadFinalHashChip {
246- air : Sha256TestAir {
247- sub_air : Sha256Air :: new ( bitwise_bus, SELF_BUS_IDX ) ,
265+ let chip = ShaTestBadFinalHashChip {
266+ air : ShaTestAir {
267+ sub_air : ShaAir :: < C > :: new ( bitwise_bus, SELF_BUS_IDX ) ,
248268 } ,
249269 bitwise_lookup_chip : bitwise_chip. clone ( ) ,
250270 records : random_records,
@@ -253,3 +273,15 @@ fn test_sha256_final_hash_constraints() {
253273 let tester = tester. build ( ) . load ( chip) . load ( bitwise_chip) . finalize ( ) ;
254274 tester. simple_test ( ) . expect ( "Verification failed" ) ;
255275}
276+
277+ #[ test]
278+ #[ should_panic]
279+ fn test_sha256_final_hash_constraints ( ) {
280+ test_sha_final_hash_constraints :: < Sha256Config > ( ) ;
281+ }
282+
283+ #[ test]
284+ #[ should_panic]
285+ fn test_sha512_final_hash_constraints ( ) {
286+ test_sha_final_hash_constraints :: < Sha512Config > ( ) ;
287+ }
0 commit comments