Skip to content

Commit f033f9d

Browse files
authored
Merge pull request #45 from PathwayCommons/development
Use mixed precision
2 parents b18897f + 08a64ea commit f033f9d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

pathway_abstract_classifier/pathway_abstract_classifier.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,12 @@
33
import ktrain
44
from cached_path import cached_path
55

6+
from tensorflow.keras import mixed_precision # erroneous missing import
7+
import tensorflow as tf
8+
9+
if len(tf.config.list_physical_devices('GPU')) > 0:
10+
mixed_precision.set_global_policy('mixed_float16')
11+
612

713
class Prediction(NamedTuple):
814
document: Dict[str, str]

0 commit comments

Comments
 (0)