Skip to content
This repository was archived by the owner on Feb 3, 2025. It is now read-only.

Commit 6128d0b

Browse files
author
DEKHTIARJonathan
committed
Future Proofing TF Device Memory Config APIs
1 parent 8e8ae15 commit 6128d0b

File tree

1 file changed

+17
-9
lines changed

1 file changed

+17
-9
lines changed

tftrt/examples/benchmark_runner.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ def __init__(self, args):
7878
self._config_gpu_memory(self._args.gpu_mem_cap)
7979

8080
def _config_gpu_memory(self, gpu_mem_cap):
81-
gpus = tf.config.experimental.list_physical_devices('GPU')
81+
try:
82+
gpus = tf.config.list_physical_devices('GPU')
83+
except AttributeError:
84+
gpus = tf.config.experimental.list_physical_devices('GPU')
8285

8386
if not gpus:
8487
raise RuntimeError("No GPUs has been found.")
@@ -90,15 +93,20 @@ def _config_gpu_memory(self, gpu_mem_cap):
9093
for gpu in gpus:
9194
try:
9295
if not gpu_mem_cap:
93-
tf.config.experimental.set_memory_growth(gpu, True)
96+
try:
97+
tf.config.set_memory_growth(gpu, True)
98+
except AttributeError:
99+
tf.config.experimental.set_memory_growth(gpu, True)
100+
94101
else:
95-
tf.config.experimental.set_virtual_device_configuration(
96-
gpu, [
97-
tf.config.experimental.VirtualDeviceConfiguration(
98-
memory_limit=gpu_mem_cap
99-
)
100-
]
101-
)
102+
try:
103+
set_virtual_device_configuration = tf.config.set_virtual_device_configuration
104+
device_config = tf.config.LogicalDeviceConfiguration(memory_limit=gpu_mem_cap)
105+
except AttributeError:
106+
set_virtual_device_configuration = tf.config.experimental.set_virtual_device_configuration
107+
device_config = tf.config.experimental.VirtualDeviceConfiguration(memory_limit=gpu_mem_cap)
108+
109+
set_virtual_device_configuration(gpu, [device_config])
102110
except RuntimeError as e:
103111
print('Can not set GPU memory config', e)
104112

0 commit comments

Comments
 (0)