@@ -78,7 +78,10 @@ def __init__(self, args):
78
78
self ._config_gpu_memory (self ._args .gpu_mem_cap )
79
79
80
80
def _config_gpu_memory (self , gpu_mem_cap ):
81
- gpus = tf .config .experimental .list_physical_devices ('GPU' )
81
+ try :
82
+ gpus = tf .config .list_physical_devices ('GPU' )
83
+ except AttributeError :
84
+ gpus = tf .config .experimental .list_physical_devices ('GPU' )
82
85
83
86
if not gpus :
84
87
raise RuntimeError ("No GPUs has been found." )
@@ -90,15 +93,20 @@ def _config_gpu_memory(self, gpu_mem_cap):
90
93
for gpu in gpus :
91
94
try :
92
95
if not gpu_mem_cap :
93
- tf .config .experimental .set_memory_growth (gpu , True )
96
+ try :
97
+ tf .config .set_memory_growth (gpu , True )
98
+ except AttributeError :
99
+ tf .config .experimental .set_memory_growth (gpu , True )
100
+
94
101
else :
95
- tf .config .experimental .set_virtual_device_configuration (
96
- gpu , [
97
- tf .config .experimental .VirtualDeviceConfiguration (
98
- memory_limit = gpu_mem_cap
99
- )
100
- ]
101
- )
102
+ try :
103
+ set_virtual_device_configuration = tf .config .set_virtual_device_configuration
104
+ device_config = tf .config .LogicalDeviceConfiguration (memory_limit = gpu_mem_cap )
105
+ except AttributeError :
106
+ set_virtual_device_configuration = tf .config .experimental .set_virtual_device_configuration
107
+ device_config = tf .config .experimental .VirtualDeviceConfiguration (memory_limit = gpu_mem_cap )
108
+
109
+ set_virtual_device_configuration (gpu , [device_config ])
102
110
except RuntimeError as e :
103
111
print ('Can not set GPU memory config' , e )
104
112
0 commit comments