diff --git a/intervention/circle_probe_interventions.py b/intervention/circle_probe_interventions.py index d936af3..8939dbd 100644 --- a/intervention/circle_probe_interventions.py +++ b/intervention/circle_probe_interventions.py @@ -41,7 +41,7 @@ choices=["llama", "mistral"], help="Choose 'llama' or 'mistral' model", ) - parser.add_argument("--device", type=int, default=4, help="CUDA device number") + parser.add_argument("--device", type=str, default="4", help="CUDA device number, or full device string") parser.add_argument( "--use_inverse_regression_probe", action="store_true", @@ -73,7 +73,7 @@ help="Probe on linear representation with center of 0.", ) args = parser.parse_args() - device = f"cuda:{args.device}" + device = (f"cuda:{args.device}" if torch.cuda.is_available() else "cpu") if args.device.isnumeric() else args.device day_month_choice = args.problem_type circle_letter = args.intervene_on model_name = args.model @@ -100,7 +100,7 @@ # use_inverse_regression_probe = False # intervention_pca_k = 5 - device = "cuda:4" + device = "cuda:4" if torch.cuda.is_available() else "cpu" circle_letter = "c" day_month_choice = "day" model_name = "mistral" diff --git a/intervention/days_of_week_task.py b/intervention/days_of_week_task.py index cbb5fa8..4e8e061 100644 --- a/intervention/days_of_week_task.py +++ b/intervention/days_of_week_task.py @@ -11,7 +11,7 @@ from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() else "cpu" # # %% diff --git a/intervention/months_of_year_task.py b/intervention/months_of_year_task.py index 161ef29..6adfef3 100644 --- a/intervention/months_of_year_task.py +++ b/intervention/months_of_year_task.py @@ -5,13 +5,14 @@ setup_notebook() +import torch import numpy as np import transformer_lens from task import Problem, get_acts, plot_pca, get_all_acts, get_acts_pca from task import activation_patching -device = "cuda:4" +device = "cuda:4" if torch.cuda.is_available() else "cpu" # # %%