Skip to content

Commit 885163e

Browse files
authored
File patterns in read_storage: wildcard, globstar & braces (#1309)
1 parent 8355e28 commit 885163e

File tree

9 files changed

+776
-37
lines changed

9 files changed

+776
-37
lines changed

src/datachain/client/fsspec.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
DELIMITER = "/" # Path delimiter.
4545

4646
DATA_SOURCE_URI_PATTERN = re.compile(r"^[\w]+:\/\/.*$")
47+
CLOUD_STORAGE_PROTOCOLS = {"s3", "gs", "az", "hf"}
4748

4849
ResultQueue = asyncio.Queue[Optional[Sequence["File"]]]
4950

@@ -62,6 +63,16 @@ def _is_win_local_path(uri: str) -> bool:
6263
return False
6364

6465

66+
def is_cloud_uri(uri: str) -> bool:
67+
protocol = urlparse(uri).scheme
68+
return protocol in CLOUD_STORAGE_PROTOCOLS
69+
70+
71+
def get_cloud_schemes() -> list[str]:
72+
"""Get list of cloud storage scheme prefixes."""
73+
return [f"{p}://" for p in CLOUD_STORAGE_PROTOCOLS]
74+
75+
6576
class Bucket(NamedTuple):
6677
name: str
6778
uri: "StorageURI"

src/datachain/lib/dc/storage.py

Lines changed: 67 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
33
from functools import reduce
44
from typing import TYPE_CHECKING, Optional, Union
55

6+
from datachain.lib.dc.storage_pattern import (
7+
apply_glob_filter,
8+
expand_brace_pattern,
9+
should_use_recursion,
10+
split_uri_pattern,
11+
validate_cloud_bucket_name,
12+
)
613
from datachain.lib.file import FileType, get_file_type
714
from datachain.lib.listing import get_file_info, get_listing, list_bucket, ls
815
from datachain.query import Session
@@ -38,14 +45,18 @@ def read_storage(
3845
It returns the chain itself as usual.
3946
4047
Parameters:
41-
uri: storage URI with directory or list of URIs.
42-
URIs must start with storage prefix such
43-
as `s3://`, `gs://`, `az://` or "file:///"
48+
uri: Storage path(s) or URI(s). Can be a local path or start with a
49+
storage prefix like `s3://`, `gs://`, `az://`, `hf://` or "file:///".
50+
Supports glob patterns:
51+
- `*` : wildcard
52+
- `**` : recursive wildcard
53+
- `?` : single character
54+
- `{a,b}` : brace expansion
4455
type: read file as "binary", "text", or "image" data. Default is "binary".
4556
recursive: search recursively for the given path.
46-
column: Created column name.
57+
column: Column name that will contain File objects. Default is "file".
4758
update: force storage reindexing. Default is False.
48-
anon: If True, we will treat cloud bucket as public one
59+
anon: If True, we will treat cloud bucket as public one.
4960
client_config: Optional client configuration for the storage client.
5061
delta: If True, only process new or changed files instead of reprocessing
5162
everything. This saves time by skipping files that were already processed in
@@ -80,12 +91,19 @@ def read_storage(
8091
chain = dc.read_storage("s3://my-bucket/my-dir")
8192
```
8293
94+
Match all .json files recursively using glob pattern
95+
```py
96+
chain = dc.read_storage("gs://bucket/meta/**/*.json")
97+
```
98+
99+
Match image file extensions for directories with pattern
100+
```py
101+
chain = dc.read_storage("s3://bucket/202?/**/*.{jpg,jpeg,png}")
102+
```
103+
83104
Multiple URIs:
84105
```python
85-
chain = dc.read_storage([
86-
"s3://bucket1/dir1",
87-
"s3://bucket2/dir2"
88-
])
106+
chain = dc.read_storage(["s3://my-bkt/dir1", "s3://bucket2/dir2/dir3"])
89107
```
90108
91109
With AWS S3-compatible storage:
@@ -95,19 +113,6 @@ def read_storage(
95113
client_config = {"aws_endpoint_url": "<minio-endpoint-url>"}
96114
)
97115
```
98-
99-
Pass existing session
100-
```py
101-
session = Session.get()
102-
chain = dc.read_storage([
103-
"path/to/dir1",
104-
"path/to/dir2"
105-
], session=session, recursive=True)
106-
```
107-
108-
Note:
109-
When using multiple URIs with `update=True`, the function optimizes by
110-
avoiding redundant updates for URIs pointing to the same storage location.
111116
"""
112117
from .datachain import DataChain
113118
from .datasets import read_dataset
@@ -130,13 +135,36 @@ def read_storage(
130135
if not uris:
131136
raise ValueError("No URIs provided")
132137

138+
# Then expand all URIs that contain brace patterns
139+
expanded_uris = []
140+
for single_uri in uris:
141+
uri_str = str(single_uri)
142+
validate_cloud_bucket_name(uri_str)
143+
expanded_uris.extend(expand_brace_pattern(uri_str))
144+
145+
# Now process each expanded URI
133146
chains = []
134147
listed_ds_name = set()
135148
file_values = []
136149

137-
for single_uri in uris:
150+
updated_uris = set()
151+
152+
for single_uri in expanded_uris:
153+
# Check if URI contains glob patterns and split them
154+
base_uri, glob_pattern = split_uri_pattern(single_uri)
155+
156+
# If a pattern is found, use the base_uri for listing
157+
# The pattern will be used for filtering later
158+
list_uri_to_use = base_uri if glob_pattern else single_uri
159+
160+
# Avoid double updates for the same URI
161+
update_single_uri = False
162+
if update and (list_uri_to_use not in updated_uris):
163+
updated_uris.add(list_uri_to_use)
164+
update_single_uri = True
165+
138166
list_ds_name, list_uri, list_path, list_ds_exists = get_listing(
139-
single_uri, session, update=update
167+
list_uri_to_use, session, update=update_single_uri
140168
)
141169

142170
# list_ds_name is None if object is a file, we don't want to use cache
@@ -185,7 +213,21 @@ def lst_fn(ds_name, lst_uri):
185213
lambda ds_name=list_ds_name, lst_uri=list_uri: lst_fn(ds_name, lst_uri)
186214
)
187215

188-
chains.append(ls(dc, list_path, recursive=recursive, column=column))
216+
# If a glob pattern was detected, use it for filtering
217+
# Otherwise, use the original list_path from get_listing
218+
if glob_pattern:
219+
# Determine if we should use recursive listing based on the pattern
220+
use_recursive = should_use_recursion(glob_pattern, recursive or False)
221+
222+
# Apply glob filter - no need for brace expansion here as it's done above
223+
chain = apply_glob_filter(
224+
dc, glob_pattern, list_path, use_recursive, column
225+
)
226+
chains.append(chain)
227+
else:
228+
# No glob pattern detected, use normal ls behavior
229+
chains.append(ls(dc, list_path, recursive=recursive, column=column))
230+
189231
listed_ds_name.add(list_ds_name)
190232

191233
storage_chain = None if not chains else reduce(lambda x, y: x.union(y), chains)

0 commit comments

Comments
 (0)