Skip to content

Commit 68798a9

Browse files
committed
Add missing file
1 parent 4b0a3ca commit 68798a9

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

gpu_environment.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# coding=utf-8
2+
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import tensorflow as tf
17+
import numpy as np
18+
19+
def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
20+
initializer=None, regularizer=None,
21+
trainable=True,
22+
*args, **kwargs):
23+
"""Custom variable getter that forces trainable variables to be stored in
24+
float32 precision and then casts them to the training precision.
25+
"""
26+
storage_dtype = tf.float32 if trainable else dtype
27+
variable = getter(name, shape, dtype=storage_dtype,
28+
initializer=initializer, regularizer=regularizer,
29+
trainable=trainable,
30+
*args, **kwargs)
31+
if trainable and dtype != tf.float32:
32+
variable = tf.cast(variable, dtype)
33+
return variable
34+
35+
def get_custom_getter(compute_type):
36+
return float32_variable_storage_getter if compute_type == tf.float16 else None

0 commit comments

Comments
 (0)