diff --git a/tests/e2e/mnist.py b/tests/e2e/mnist.py index 143a6b6c..bc9101a8 100644 --- a/tests/e2e/mnist.py +++ b/tests/e2e/mnist.py @@ -42,6 +42,11 @@ print("ACCELERATOR: is ", os.getenv("ACCELERATOR")) ACCELERATOR = os.getenv("ACCELERATOR") +# If GPU is requested but CUDA is not available, fall back to CPU +if ACCELERATOR == "gpu" and not torch.cuda.is_available(): + print("Warning: GPU requested but CUDA is not available. Falling back to CPU.") + ACCELERATOR = "cpu" + STORAGE_BUCKET_EXISTS = "AWS_DEFAULT_ENDPOINT" in os.environ print("STORAGE_BUCKET_EXISTS: ", STORAGE_BUCKET_EXISTS) diff --git a/tests/e2e/mnist_pip_requirements.txt b/tests/e2e/mnist_pip_requirements.txt index 60811f18..3df95458 100644 --- a/tests/e2e/mnist_pip_requirements.txt +++ b/tests/e2e/mnist_pip_requirements.txt @@ -1,4 +1,6 @@ +--extra-index-url https://download.pytorch.org/whl/cu118 +torch==2.5.1 +torchvision==0.20.1 pytorch_lightning==1.9.5 torchmetrics==0.9.1 -torchvision==0.20.1 minio