Skip to content

Commit f0b755a

Browse files
authored
Hot fix: Fix path resolution (#29)
1 parent 51b7fee commit f0b755a

File tree

2 files changed

+45
-14
lines changed

2 files changed

+45
-14
lines changed

litdata/processing/data_processor.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,25 @@ def _get_item_filesizes(items: List[Any], base_path: str = "") -> List[int]:
339339
return item_sizes
340340

341341

342+
def _to_path(element: str) -> str:
343+
return element if _IS_IN_STUDIO and element.startswith("/teamspace") else str(Path(element).resolve())
344+
345+
346+
def _is_path(input_dir: Optional[str], element: Any) -> bool:
347+
if not isinstance(element, str):
348+
return False
349+
350+
if _IS_IN_STUDIO and input_dir is not None:
351+
if element.startswith(input_dir):
352+
return True
353+
354+
element = str(Path(element).absolute())
355+
if element.startswith(input_dir):
356+
return True
357+
358+
return os.path.exists(element)
359+
360+
342361
class BaseWorker:
343362
def __init__(
344363
self,
@@ -381,7 +400,6 @@ def __init__(
381400
self.remove_queue: Queue = Queue()
382401
self.progress_queue: Queue = progress_queue
383402
self.error_queue: Queue = error_queue
384-
self._collected_items = 0
385403
self._counter = 0
386404
self._last_time = time()
387405
self._index_counter = 0
@@ -504,22 +522,13 @@ def _collect_paths(self) -> None:
504522
for item in self.items:
505523
flattened_item, spec = tree_flatten(item)
506524

507-
def is_path(element: Any) -> bool:
508-
if not isinstance(element, str):
509-
return False
510-
511-
element: str = str(Path(element).resolve())
512-
if _IS_IN_STUDIO and self.input_dir.path is not None:
513-
if self.input_dir.path.startswith("/teamspace/studios/this_studio"):
514-
return os.path.exists(element)
515-
return element.startswith(self.input_dir.path)
516-
return os.path.exists(element)
517-
518525
# For speed reasons, we assume starting with `self.input_dir` is enough to be a real file.
519526
# Other alternative would be too slow.
520527
# TODO: Try using dictionary for higher accurary.
521528
indexed_paths = {
522-
index: str(Path(element).resolve()) for index, element in enumerate(flattened_item) if is_path(element)
529+
index: _to_path(element)
530+
for index, element in enumerate(flattened_item)
531+
if _is_path(self.input_dir.path, element)
523532
}
524533

525534
if len(indexed_paths) == 0:
@@ -537,7 +546,6 @@ def is_path(element: Any) -> bool:
537546
self.paths.append(paths)
538547

539548
items.append(tree_unflatten(flattened_item, spec))
540-
self._collected_items += 1
541549

542550
self.items = items
543551

tests/processing/test_data_processor.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
DataTransformRecipe,
2222
_download_data_target,
2323
_get_item_filesizes,
24+
_is_path,
2425
_map_items_to_workers_sequentially,
2526
_map_items_to_workers_weighted,
2627
_remove_target,
28+
_to_path,
2729
_upload_fn,
2830
_wait_for_disk_usage_higher_than_threshold,
2931
_wait_for_file_to_exist,
@@ -1136,3 +1138,24 @@ def test_load_torch_audio_from_wav_file(tmpdir, compression):
11361138
tensor = torchaudio.load(sample)
11371139
assert tensor[0].shape == torch.Size([1, 16000])
11381140
assert tensor[1] == 16000
1141+
1142+
1143+
def test_is_path_valid_in_studio(monkeypatch, tmpdir):
1144+
filepath = os.path.join(tmpdir, "a.png")
1145+
with open(filepath, "w") as f:
1146+
f.write("Hello World")
1147+
1148+
monkeypatch.setattr(data_processor_module, "_IS_IN_STUDIO", True)
1149+
1150+
assert _is_path("/teamspace/studios/this_studio", "/teamspace/studios/this_studio/a.png")
1151+
assert _is_path("/teamspace/studios/this_studio", filepath)
1152+
1153+
1154+
@pytest.mark.skipif(sys.platform == "win32", reason="skip windows")
1155+
def test_to_path(tmpdir):
1156+
filepath = os.path.join(tmpdir, "a.png")
1157+
with open(filepath, "w") as f:
1158+
f.write("Hello World")
1159+
1160+
assert _to_path("/teamspace/studios/this_studio/a.png") == "/teamspace/studios/this_studio/a.png"
1161+
assert _to_path(filepath) == filepath

0 commit comments

Comments
 (0)