Skip to content

Commit 5759518

Browse files
committed
fix small accelerate nits
1 parent 7191802 commit 5759518

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

evaluation/generation/generate.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ def get_args():
1313
parser.add_argument("--greedy", action="store_true")
1414
parser.add_argument("--top-k", type=int, default=0)
1515
parser.add_argument("--offload_folder", type=str, help="offload folder for accelerate", default="./offload")
16+
parser.add_argument("--max_memory", type=str, help="max memory per GPU", default="30GB")
17+
parser.add_argument("--max_cpu_memory", type=str, help="max memory on CPU", default="300GB")
1618

1719
return parser.parse_args()
1820

@@ -40,9 +42,12 @@ def main():
4042
args.checkpoint,
4143
device_map="auto" if args.parallelize else None,
4244
torch_dtype=torch.bfloat16,
43-
revision="gs{}".format(args.global_step) if args.global_step else None
44-
offload_folder=args.offload_folder is args.parallelize else None,
45+
revision="gs{}".format(args.global_step) if args.global_step else None,
46+
max_memory=args.max_memory if args.parallelize else None,
47+
max_cpu_memory=args.max_cpu_memory if args.parallelize else None,
48+
offload_folder=args.offload_folder if args.parallelize else None,
4549
)
50+
4651
print(f"Loaded model in {datetime.datetime.now() - start}")
4752

4853
text = ''

0 commit comments

Comments
 (0)