@@ -43,6 +43,26 @@ use transcript::{BasicTranscript, Transcript};
43
43
use witness:: next_pow2_instance_padding;
44
44
45
45
use ceno_gpu:: gl64:: CudaHalGL64 ;
46
+ use cudarc:: driver:: { CudaDevice , DriverError } ;
47
+
48
+ use once_cell:: sync:: Lazy ;
49
+ use std:: sync:: Mutex ;
50
+ // static CUDA_HAL: Lazy<Mutex<CudaHalGL64>> = Lazy::new(|| {
51
+ // Mutex::new(CudaHalGL64::new().unwrap())
52
+ // });
53
+
54
// Process-wide, lazily initialized CUDA device handle.
// The initializer calls `CudaDevice::new(0)` the first time this static is
// accessed and stores the outcome, success or `DriverError`, so every later
// access observes the same `Result`.
// NOTE(review): the argument `0` is presumably the device ordinal (first GPU);
// confirm against the cudarc documentation.
+ static CUDA_DEVICE : Lazy < Result < Arc < CudaDevice > , DriverError > > = Lazy :: new ( || {
55
+ CudaDevice :: new ( 0 )
56
+ } ) ;
57
// Process-wide, lazily initialized GPU HAL, wrapped in `Arc<Mutex<_>>`.
// The initializer's outcome is stored as a `Result`, so a failure is kept and
// handed back to every caller instead of being retried.
//
// Initialization steps, in order:
//   1. Borrow the result held by `CUDA_DEVICE`; on error, turn it into a
//      "Device init failed: ..." message and return early.
//   2. Call `bind_to_thread()` on the device, returning early on error.
//      NOTE(review): presumably this binds the CUDA context to the thread that
//      happens to run this initializer; confirm against cudarc docs.
//   3. Construct `CudaHalGL64::new()`, wrap it in `Arc<Mutex<_>>` on success,
//      or box the error as `Box<dyn Error + Send + Sync>` on failure.
+ static CUDA_HAL : Lazy < Result < Arc < Mutex < CudaHalGL64 > > , Box < dyn std:: error:: Error + Send + Sync > > > = Lazy :: new ( || {
58
// The formatted `String` is converted into the boxed error type by `?`.
+ let device = CUDA_DEVICE . as_ref ( ) . map_err ( |e| format ! ( "Device init failed: {:?}" , e) ) ?;
59
+ device. bind_to_thread ( ) ?;
60
+
61
+ CudaHalGL64 :: new ( )
62
+ . map ( |hal| Arc :: new ( Mutex :: new ( hal) ) )
63
+ . map_err ( |e| Box :: new ( e) as Box < dyn std:: error:: Error + Send + Sync > )
64
+ } ) ;
65
+
46
66
47
67
pub struct GpuTowerProver ;
48
68
@@ -295,7 +315,12 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> TraceCommitter<GpuBa
295
315
// panic!("error: type conversion failed");
296
316
// };
297
317
298
- let cuda_hal = CudaHalGL64 :: new ( ) . unwrap ( ) ;
318
+ // let cuda_hal = CUDA_HAL.lock().unwrap(); // CudaHalGL64::new().unwrap();
319
+ let device = CUDA_DEVICE . as_ref ( ) . map_err ( |e| format ! ( "Device not available: {:?}" , e) ) . unwrap ( ) ;
320
+ device. bind_to_thread ( ) . unwrap ( ) ;
321
+ let hal_arc = CUDA_HAL . as_ref ( ) . map_err ( |e| format ! ( "HAL not available: {:?}" , e) ) . unwrap ( ) ;
322
+ let cuda_hal = hal_arc. lock ( ) . unwrap ( ) ;
323
+
299
324
let traces_gl64: Vec < witness:: RowMajorMatrix < p3:: goldilocks:: Goldilocks > > =
300
325
unsafe { std:: mem:: transmute ( vec_traces. clone ( ) ) } ;
301
326
let pcs_data = cuda_hal. basefold . batch_commit ( traces_gl64) . unwrap ( ) ;
@@ -863,6 +888,10 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> MainSumcheckProver<G
863
888
}
864
889
}
865
890
891
+ use p3:: field:: extension:: BinomialExtensionField ;
892
// Shorthand for the Goldilocks field type from `p3`.
+ type GL64 = p3:: goldilocks:: Goldilocks ;
893
// `BinomialExtensionField` over `GL64` with const parameter 2.
// NOTE(review): the `2` is presumably the extension degree; confirm against p3 docs.
+ type EGL64 = BinomialExtensionField < GL64 , 2 > ;
894
+
866
895
impl < E : ExtensionField , PCS : PolynomialCommitmentScheme < E > > OpeningProver < GpuBackend < E , PCS > >
867
896
for GpuProver < GpuBackend < E , PCS > >
868
897
{
@@ -880,9 +909,14 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> OpeningProver<GpuBac
880
909
panic ! ( "GPU backend only supports Goldilocks base field" ) ;
881
910
}
882
911
883
- use p3:: field:: extension:: BinomialExtensionField ;
884
- type EGL64 = BinomialExtensionField < p3:: goldilocks:: Goldilocks , 2 > ;
885
- let cuda_hal = CudaHalGL64 :: new ( ) . unwrap ( ) ;
912
+ // use p3::field::extension::BinomialExtensionField;
913
+ // type GL64 = p3::goldilocks::Goldilocks;
914
+ // type EGL64 = BinomialExtensionField<GL64, 2>;
915
+ // let cuda_hal = CUDA_HAL.lock().unwrap(); //CudaHalGL64::new().unwrap();
916
+ let device = CUDA_DEVICE . as_ref ( ) . map_err ( |e| format ! ( "Device not available: {:?}" , e) ) . unwrap ( ) ;
917
+ device. bind_to_thread ( ) . unwrap ( ) ;
918
+ let hal_arc = CUDA_HAL . as_ref ( ) . map_err ( |e| format ! ( "HAL not available: {:?}" , e) ) . unwrap ( ) ;
919
+ let cuda_hal = hal_arc. lock ( ) . unwrap ( ) ;
886
920
887
921
let mut rounds = vec ! [ ] ;
888
922
rounds. push ( (
@@ -913,13 +947,17 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> OpeningProver<GpuBac
913
947
) ) ;
914
948
}
915
949
950
+
951
+ use ceno_gpu:: gl64:: buffer:: BufferImpl ;
952
+ use ceno_gpu:: BasefoldCommitmentWithWitness as BasefoldCommitmentWithWitnessGpu ;
953
+
916
954
// Type conversions using unsafe transmute
917
955
let pp_gl64: & mpcs:: basefold:: structure:: BasefoldProverParams < EGL64 , mpcs:: BasefoldRSParams > =
918
956
unsafe { std:: mem:: transmute ( self . pp . as_ref ( ) . unwrap ( ) ) } ;
919
957
let rounds_gl64: Vec < _ > = rounds
920
958
. iter ( )
921
959
. map ( |( commitment, point_eval_pairs) | {
922
- let commitment_gl64: & mpcs :: BasefoldCommitmentWithWitness < EGL64 > =
960
+ let commitment_gl64: & BasefoldCommitmentWithWitnessGpu < GL64 , BufferImpl < GL64 > > =
923
961
unsafe { std:: mem:: transmute ( * commitment) } ;
924
962
let point_eval_pairs_gl64: Vec < _ > = point_eval_pairs
925
963
. iter ( )
0 commit comments