|
6 | 6 | """ |
7 | 7 | import multiprocessing |
8 | 8 | import os |
| 9 | +import re |
9 | 10 | import sys |
10 | 11 | import tempfile |
11 | 12 | from unittest.mock import MagicMock, patch |
@@ -207,23 +208,51 @@ def cxx_search_dirs(blas_libs, mock_system): |
207 | 208 | yield f"libraries: ={d}".encode(sys.stdout.encoding), flags |
208 | 209 |
|
209 | 210 |
|
| 211 | +@pytest.fixture( |
| 212 | + scope="function", params=[False, True], ids=["Working_CXX", "Broken_CXX"] |
| 213 | +) |
| 214 | +def cxx_search_dirs_status(request): |
| 215 | + return request.param |
| 216 | + |
| 217 | + |
210 | 218 | @patch("pytensor.link.c.cmodule.std_lib_dirs", return_value=[]) |
211 | 219 | @patch("pytensor.link.c.cmodule.check_mkl_openmp", return_value=None) |
212 | 220 | def test_default_blas_ldflags( |
213 | | - mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs |
| 221 | + mock_std_lib_dirs, mock_check_mkl_openmp, cxx_search_dirs, cxx_search_dirs_status |
214 | 222 | ): |
215 | 223 | cxx_search_dirs, expected_blas_ldflags = cxx_search_dirs |
216 | 224 | mock_process = MagicMock() |
217 | | - mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, None) |
| 225 | + if cxx_search_dirs_status: |
| 226 | + error_message = "" |
| 227 | + mock_process.communicate = lambda *args, **kwargs: (cxx_search_dirs, b"") |
| 228 | + mock_process.returncode = 0 |
| 229 | + else: |
| 230 | + error_message = "Unsupported argument -print-search-dirs" |
| 231 | + error_message_bytes = error_message.encode(sys.stderr.encoding) |
| 232 | + mock_process.communicate = lambda *args, **kwargs: (b"", error_message_bytes) |
| 233 | + mock_process.returncode = 1 |
218 | 234 | with patch("pytensor.link.c.cmodule.subprocess_Popen", return_value=mock_process): |
219 | 235 | with patch.object( |
220 | 236 | pytensor.link.c.cmodule.GCC_compiler, |
221 | 237 | "try_compile_tmp", |
222 | 238 | return_value=(True, True), |
223 | 239 | ): |
224 | | - assert set(default_blas_ldflags().split(" ")) == set( |
225 | | - expected_blas_ldflags.split(" ") |
226 | | - ) |
| 240 | + if cxx_search_dirs_status: |
| 241 | + assert set(default_blas_ldflags().split(" ")) == set( |
| 242 | + expected_blas_ldflags.split(" ") |
| 243 | + ) |
| 244 | + else: |
| 245 | + expected_warning = re.escape( |
| 246 | + "Pytensor cxx failed to communicate its search dirs. As a consequence, " |
| 247 | + "it might not be possible to automatically determine the blas link flags to use.\n" |
| 248 | + f"Command that was run: {config.cxx} -print-search-dirs\n" |
| 249 | + f"Output printed to stderr: {error_message}" |
| 250 | + ) |
| 251 | + with pytest.warns( |
| 252 | + UserWarning, |
| 253 | + match=expected_warning, |
| 254 | + ): |
| 255 | + assert default_blas_ldflags() == "" |
227 | 256 |
|
228 | 257 |
|
229 | 258 | def test_default_blas_ldflags_no_cxx(): |
|
0 commit comments