@@ -13,6 +13,8 @@ def get_args():
13
13
parser .add_argument ("--greedy" , action = "store_true" )
14
14
parser .add_argument ("--top-k" , type = int , default = 0 )
15
15
parser .add_argument ("--offload_folder" , type = str , help = "offload folder for accelerate" , default = "./offload" )
16
+ parser .add_argument ("--max_memory" , type = str , help = "max memory per GPU" , default = "30GB" )
17
+ parser .add_argument ("--max_cpu_memory" , type = str , help = "max memory on CPU" , default = "300GB" )
16
18
17
19
return parser .parse_args ()
18
20
@@ -40,9 +42,12 @@ def main():
40
42
args .checkpoint ,
41
43
device_map = "auto" if args .parallelize else None ,
42
44
torch_dtype = torch .bfloat16 ,
43
- revision = "gs{}" .format (args .global_step ) if args .global_step else None
44
- offload_folder = args .offload_folder is args .parallelize else None ,
45
+ revision = "gs{}" .format (args .global_step ) if args .global_step else None ,
46
+ max_memory = args .max_memory if args .parallelize else None ,
47
+ max_cpu_memory = args .max_cpu_memory if args .parallelize else None ,
48
+ offload_folder = args .offload_folder if args .parallelize else None ,
45
49
)
50
+
46
51
print (f"Loaded model in { datetime .datetime .now () - start } " )
47
52
48
53
text = ''
0 commit comments