From 94360a51fdaf905b51b5bd617c870216560ad14a Mon Sep 17 00:00:00 2001 From: whitewatercn Date: Mon, 10 Jun 2024 20:05:02 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86--data-path?= =?UTF-8?q?=E4=B8=BA=E5=BF=85=E9=A1=BB=E4=BD=BF=E7=94=A8=E7=9A=84=E5=8F=82?= =?UTF-8?q?=E6=95=B0,=E5=8F=96=E6=B6=88=E5=85=B6=E9=BB=98=E8=AE=A4?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE,=E6=8F=90=E7=A4=BA=E7=94=A8=E6=88=B7?= =?UTF-8?q?=E8=AE=BE=E7=BD=AE=E4=B8=BA=E8=87=AA=E5=B7=B1=20=E7=9A=84datase?= =?UTF-8?q?t=E8=B7=AF=E5=BE=84,=E4=BF=AE=E6=94=B9=E4=BA=86--model=E9=BB=98?= =?UTF-8?q?=E8=AE=A4=E5=8F=82=E6=95=B0=E4=B8=BAcpu,=E4=BB=A5=E5=85=8D?= =?UTF-8?q?=E6=97=A0gpu=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- code/chapter-8/01_classification/train_main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/code/chapter-8/01_classification/train_main.py b/code/chapter-8/01_classification/train_main.py index 77e8c71..58fc3b9 100644 --- a/code/chapter-8/01_classification/train_main.py +++ b/code/chapter-8/01_classification/train_main.py @@ -28,10 +28,10 @@ def get_args_parser(add_help=True): parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help) - parser.add_argument("--data-path", default=r"G:\deep_learning_data\chest_xray", type=str, help="dataset path") + parser.add_argument("--data-path", required = True, type=str, help="dataset path, like G:\deep_learning_data\chest_xray/train") parser.add_argument("--model", default="convnext-tiny", type=str, help="model name; resnet50/convnext/convnext-tiny") - parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)") + parser.add_argument("--device", default="cpu", type=str, help="device (Use cuda or cpu Default: cuda)") parser.add_argument( "-b", "--batch-size", default=8, type=int, help="images per gpu, the total batch size is $NGPU x batch_size" ) From b788ae4ad76a08fa34288e62109d893013c63edc Mon Sep 17 00:00:00 2001 From: whitewatercn Date: Mon, 10 Jun 2024 20:37:50 +0800 Subject: [PATCH 2/2] add support for apple sillicon --- code/chapter-8/01_classification/train_main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/code/chapter-8/01_classification/train_main.py b/code/chapter-8/01_classification/train_main.py index 58fc3b9..c431bcd 100644 --- a/code/chapter-8/01_classification/train_main.py +++ b/code/chapter-8/01_classification/train_main.py @@ -216,5 +216,7 @@ def main(args): if __name__ == "__main__": args = get_args_parser().parse_args() utils.setup_seed(args.random_seed) - args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device_choices = ["cuda", "mps", "cpu"] + device = torch.device(args.device) if args.device in device_choices else torch.device("cpu") + args.device = device main(args)