diff --git a/kernel_tuner/accuracy.py b/kernel_tuner/accuracy.py index 49154190..b647947c 100644 --- a/kernel_tuner/accuracy.py +++ b/kernel_tuner/accuracy.py @@ -58,22 +58,43 @@ def __call__(self, params): def _find_bfloat16_if_available(): # Try to get bfloat16 if available. - try: - from bfloat16 import bfloat16 - return bfloat16 - except ImportError: - pass + dtype = None + # get it via numpy if available try: - from tensorflow import bfloat16 - return bfloat16.as_numpy_dtype - except ImportError: + dtype = np.dtype("bfloat16") + except TypeError: pass - logging.warning( - "could not find `bfloat16` data type for numpy, " - + "please install either the package `bfloat16` or `tensorflow`" - ) + # otherwise, try ml_dtypes + if dtype is None: + try: + from ml_dtypes import bfloat16 + dtype = bfloat16 + except ImportError: + pass + + # otherwise, try jax + if dtype is None: + try: + from jax.numpy import bfloat16 + dtype = bfloat16 + except ImportError: + pass + + # otherwise, try tensorflow + if dtype is None: + try: + from tensorflow import bfloat16 + dtype = bfloat16.as_numpy_dtype + except ImportError: + pass + + if dtype is None: + logging.warning( + "could not find `bfloat16` data type for numpy, " + + "please install either the package `ml_dtypes`, `jax`, or `tensorflow`" + ) return None