Skip to content

Commit 4943871

Browse files
committed
Add new test case
1 parent ada2a5e commit 4943871

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

skillsnetwork/core.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -114,25 +114,33 @@ async def _get_chunks(url: str, chunk_size: int) -> Generator[bytes, None, None]
114114
raise Exception(f"Failed to read dataset at {url}") from None
115115

116116

117+
def _rmrf(path: Path) -> None:
118+
if path.is_dir():
119+
shutil.rmtree(path)
120+
else:
121+
path.unlink()
122+
123+
117124
def _verify_files_dont_exist(
118-
paths: Iterable[Union[str, Path]], remove_if_exist: bool = False
125+
paths: Iterable[Path], remove_if_exist: bool = False
119126
) -> None:
120127
"""
121128
Verifies all paths in 'paths' don't exist.
122129
:param paths: A iterable of strs or pathlib.Paths.
123130
:param remove_if_exist=False: Removes file at path if they already exist.
124131
:returns: None
125-
:raises FileExistsError: On the first path found that already exists.
132+
:raises FileExistsError: On the first path found that already exists if remove_if_exist is False.
126133
"""
127134
for path in paths:
128-
path = Path(path)
129-
if path.exists():
135+
# Could be a broken symlink => path.exists() is False
136+
if path.exists() or path.is_symlink():
130137
if remove_if_exist:
131-
if path.is_symlink():
132-
realpath = path.resolve()
133-
path.unlink(realpath)
134-
else:
135-
shutil.rmtree(path)
138+
while path.is_symlink():
139+
temp = path.readlink()
140+
path.unlink(missing_ok=True)
141+
path = temp
142+
if path.exists():
143+
_rmrf(path)
136144
else:
137145
raise FileExistsError(f"Error: File '{path}' already exists.")
138146

@@ -254,9 +262,9 @@ async def prepare(
254262
path / child.name
255263
for child in map(Path, tf.getnames())
256264
if len(child.parents) == 1 and _is_file_to_symlink(child)
257-
],
258-
overwrite,
259-
) # Only check if top-level fileobject
265+
], # Only check if top-level fileobject
266+
remove_if_exist=overwrite,
267+
)
260268
pbar = tqdm(iterable=tf.getmembers(), total=len(tf.getmembers()))
261269
pbar.set_description(f"Extracting {filename}")
262270
for member in pbar:
@@ -269,22 +277,24 @@ async def prepare(
269277
path / child.name
270278
for child in map(Path, zf.namelist())
271279
if len(child.parents) == 1 and _is_file_to_symlink(child)
272-
],
273-
overwrite,
280+
], # Only check if top-level fileobject
281+
remove_if_exist=overwrite,
274282
)
275283
pbar = tqdm(iterable=zf.infolist(), total=len(zf.infolist()))
276284
pbar.set_description(f"Extracting {filename}")
277285
for member in pbar:
278286
zf.extract(member=member, path=extract_dir)
279287
tmp_download_file.unlink()
280288
else:
281-
_verify_files_dont_exist([path / filename], overwrite)
289+
_verify_files_dont_exist([path / filename], remove_if_exist=overwrite)
282290
shutil.move(tmp_download_file, extract_dir / filename)
283291

284292
# If in jupyterlite environment, the extract_dir = path, so the files are already there.
285293
if not _is_jupyterlite():
286294
# If not in jupyterlite environment, symlink top-level file objects in extract_dir
287295
for child in filter(_is_file_to_symlink, extract_dir.iterdir()):
296+
if (path / child.name).is_symlink() and overwrite:
297+
(path / child.name).unlink()
288298
(path / child.name).symlink_to(child, target_is_directory=child.is_dir())
289299

290300
if verbose:

tests/test_skillsnetwork.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,21 @@ async def test_prepare_non_compressed_dataset_with_path(httpserver):
134134
await skillsnetwork.prepare_dataset(httpserver.url_for(url), path=path)
135135
assert expected_path.exists()
136136
expected_path.unlink()
137+
138+
139+
@pytest.mark.asyncio
async def test_prepare_non_compressed_dataset_with_overwrite(httpserver):
    """
    Preparing the same uncompressed dataset twice with overwrite=True must
    succeed both times and leave a complete copy of the file behind.
    """
    url = "/test.csv"
    source_path = Path("tests/test.csv")
    expected_path = Path("./test.csv")
    try:
        # First download creates the file.
        with open(source_path, "rb") as expected_data:
            httpserver.expect_request(url).respond_with_data(expected_data)
            await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
        assert expected_path.exists()
        httpserver.clear()
        # Second download must replace the existing file instead of failing.
        with open(source_path, "rb") as expected_data:
            httpserver.expect_request(url).respond_with_data(expected_data)
            await skillsnetwork.prepare_dataset(httpserver.url_for(url), overwrite=True)
        assert expected_path.exists()
        # Compare against the fixture itself rather than a hard-coded byte count.
        assert expected_path.stat().st_size == source_path.stat().st_size
    finally:
        # Clean up even when an assertion fails so later tests start clean.
        expected_path.unlink(missing_ok=True)

0 commit comments

Comments
 (0)