diff --git a/src/lightning/pytorch/loggers/mlflow.py b/src/lightning/pytorch/loggers/mlflow.py index ff9b2b0d7e542..a904a79861e3a 100644 --- a/src/lightning/pytorch/loggers/mlflow.py +++ b/src/lightning/pytorch/loggers/mlflow.py @@ -300,7 +300,18 @@ def save_dir(self) -> Optional[str]: """ if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX): - return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :] + # Handle both proper file URIs (file:///path) and legacy format (file:/path) + uri_without_prefix = self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :] + + # If it starts with ///, it's a proper file URI, use urlparse + if uri_without_prefix.startswith("///"): + from urllib.parse import urlparse + from urllib.request import url2pathname + + parsed_uri = urlparse(self._tracking_uri) + return url2pathname(parsed_uri.path) + # Legacy format: file:/path or file:./path - return as-is + return uri_without_prefix return None @property diff --git a/tests/tests_pytorch/loggers/test_mlflow.py b/tests/tests_pytorch/loggers/test_mlflow.py index c7f9dbe1fe2c6..9de0ad23a93cf 100644 --- a/tests/tests_pytorch/loggers/test_mlflow.py +++ b/tests/tests_pytorch/loggers/test_mlflow.py @@ -427,3 +427,43 @@ def test_set_tracking_uri(mlflow_mock): mlflow_mock.set_tracking_uri.assert_not_called() _ = logger.experiment mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri") + + +@mock.patch("lightning.pytorch.loggers.mlflow._get_resolve_tags", Mock()) +def test_mlflow_logger_save_dir_file_uri_handling(mlflow_mock): + """Test that save_dir correctly handles file URIs, especially on Windows.""" + import platform + + # Test proper Windows-style absolute file URI (the main fix) + logger_win = MLFlowLogger(tracking_uri="file:///C:/Dev/example/mlruns") + result_win = logger_win.save_dir + expected_win = "C:\\Dev\\example\\mlruns" if platform.system() == "Windows" else "/C:/Dev/example/mlruns" + assert result_win == expected_win + + # Test proper Unix-style absolute file URI + logger_unix = MLFlowLogger(tracking_uri="file:///home/user/mlruns") + result_unix = logger_unix.save_dir + expected_unix = "\\home\\user\\mlruns" if platform.system() == "Windows" else "/home/user/mlruns" + assert result_unix == expected_unix + + # Test proper file URI with special characters and spaces + logger_special = MLFlowLogger(tracking_uri="file:///path/with%20spaces/mlruns") + result_special = logger_special.save_dir + expected_special = "\\path\\with spaces\\mlruns" if platform.system() == "Windows" else "/path/with spaces/mlruns" + assert result_special == expected_special + + # Test legacy format used by constructor (file:/path - should return as-is) + logger_legacy = MLFlowLogger(tracking_uri="file:/tmp/mlruns") + result_legacy = logger_legacy.save_dir + expected_legacy = "/tmp/mlruns" + assert result_legacy == expected_legacy + + # Test legacy relative format + logger_rel = MLFlowLogger(tracking_uri="file:./mlruns") + result_rel = logger_rel.save_dir + expected_rel = "./mlruns" + assert result_rel == expected_rel + + # Test non-file URI (should return None) + logger_http = MLFlowLogger(tracking_uri="http://localhost:8080") + assert logger_http.save_dir is None