@@ -413,6 +413,7 @@ def add_function(name, mp_function, np_function, ocl_function,
413413 ocl_function = make_ocl ("return fmod(q, 2*M_PI);" , "sas_fmod" ),
414414)
415415
416+ # TODO: move to sas_special
416417def sas_langevin (x ):
417418 scalar = np .isscalar (x )
418419 if scalar :
@@ -451,14 +452,30 @@ def sas_langevin_x(x):
451452 mp_function = lambda x : (1 / mp .tanh (x ) - 1 / x ),
452453 np_function = sas_langevin ,
453454 #ocl_function=make_ocl("return q < 0.7 ? q*(1./3. + q*q*(-1./45. + q*q*(2./945. + q*q*(-1./4725.) + q*q*(2./93555.)))) : 1/tanh(q) - 1/q;", "sas_langevin"),
454- ocl_function = make_ocl ("return q < 1e-5 ? q/3. : 1/tanh(q) - 1/q;" , "sas_langevin" ),
455+ #ocl_function=make_ocl("return q < 1e-5 ? q/3. : 1/tanh(q) - 1/q;", "sas_langevin"),
456+ ocl_function = make_ocl ("""
457+ #if FLOAT_SIZE>4 // DOUBLE_PRECISION
458+ # define LANGEVIN_CUTOFF 0.1
459+ #else
460+ # define LANGEVIN_CUTOFF 1.0
461+ #endif
462+ const double qsq = q*q;
463+ return (q < LANGEVIN_CUTOFF) ? q / (3. + qsq / (5. + qsq/(7. + qsq/(9.)))) : 1/tanh(q) - 1/q;
464+ """ , "sas_langevin" ),
455465)
456466add_function (
457467 name = "langevin(x)/x" ,
458468 mp_function = lambda x : (1 / mp .tanh (x ) - 1 / x )/ x ,
459- #np_function=lambda x: sas_langevin(x)/x, # Note: need to test for x=0
460469 np_function = sas_langevin_x ,
461- ocl_function = make_ocl ("return q < 1e-5 ? 1./3. : (1/tanh(q) - 1/q)/q;" , "sas_langevin_x" ),
470+ ocl_function = make_ocl ("""
471+ #if FLOAT_SIZE>4 // DOUBLE_PRECISION
472+ # define LANGEVIN_CUTOFF 0.1
473+ #else
474+ # define LANGEVIN_CUTOFF 1.0
475+ #endif
476+ const double qsq = q*q;
477+ return (q < LANGEVIN_CUTOFF) ? 1. / (3. + qsq / (5. + qsq/(7. + qsq/(9.)))) : (1/tanh(q) - 1/q)/q;
478+ """ , "sas_langevin_x" ),
462479)
463480add_function (
464481 name = "gauss_coil" ,
0 commit comments