Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions kernel_tuner/accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,22 +58,43 @@ def __call__(self, params):

def _find_bfloat16_if_available():
# Try to get bfloat16 if available.
try:
from bfloat16 import bfloat16
return bfloat16
except ImportError:
pass
dtype = None

# get it via numpy if available
try:
from tensorflow import bfloat16
return bfloat16.as_numpy_dtype
except ImportError:
dtype = np.dtype("bfloat16")
except TypeError:
pass

logging.warning(
"could not find `bfloat16` data type for numpy, "
+ "please install either the package `bfloat16` or `tensorflow`"
)
# otherwise, try ml_dtypes
if dtype is None:
try:
from ml_dtypes import bfloat16
dtype = bfloat16
except ImportError:
pass

# otherwise, try jax
if dtype is None:
try:
from jax.numpy import bfloat16
dtype = bfloat16
except ImportError:
pass

# otherwise, try tensorflow
if dtype is None:
try:
from tensorflow import bfloat16
dtype = bfloat16.as_numpy_dtype
except ImportError:
pass

if dtype is None:
logging.warning(
"could not find `bfloat16` data type for numpy, "
+ "please install either the package `ml_dtypes`, `jax`, or `tensorflow`"
)

return None

Expand Down