Skip to content

Commit 4f70f1b

Browse files
BENR0martindurant
andauthored
Fix open_local returning list for pathlib.Path (#1418)
--------- Co-authored-by: Martin Durant <[email protected]>
1 parent 5f268e4 commit 4f70f1b

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

fsspec/core.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from __future__ import annotations
2+
13
import io
24
import logging
35
import os
46
import re
57
from glob import has_magic
8+
from pathlib import Path
69

710
# for backwards compat, we export cache things from here too
811
from .caching import ( # noqa: F401
@@ -469,7 +472,11 @@ def open(
469472
return out[0]
470473

471474

472-
def open_local(url, mode="rb", **storage_options):
475+
def open_local(
476+
url: str | list[str] | Path | list[Path],
477+
mode: str = "rb",
478+
**storage_options: dict,
479+
) -> str | list[str]:
473480
"""Open file(s) which can be resolved to local
474481
475482
For files which either are local, or get downloaded upon open
@@ -493,7 +500,7 @@ def open_local(url, mode="rb", **storage_options):
493500
)
494501
with of as files:
495502
paths = [f.name for f in files]
496-
if isinstance(url, str) and not has_magic(url):
503+
if (isinstance(url, str) and not has_magic(url)) or isinstance(url, Path):
497504
return paths[0]
498505
return paths
499506

fsspec/tests/test_core.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import tempfile
44
import zipfile
55
from contextlib import contextmanager
6+
from pathlib import Path
67

78
import pytest
89

@@ -101,7 +102,7 @@ def test_openfile_open(m):
101102
assert m.size("somepath") == 5
102103

103104

104-
def test_open_local():
105+
def test_open_local_w_cache():
105106
d1 = str(tempfile.mkdtemp())
106107
f1 = os.path.join(d1, "f1")
107108
open(f1, "w").write("test1")
@@ -112,6 +113,45 @@ def test_open_local():
112113
assert d2 in fn
113114

114115

116+
def test_open_local_w_magic():
117+
d1 = str(tempfile.mkdtemp())
118+
f1 = os.path.join(d1, "f1")
119+
open(f1, "w").write("test1")
120+
fn = open_local(os.path.join(d1, "f*"))
121+
assert len(fn) == 1
122+
assert isinstance(fn, list)
123+
124+
125+
def test_open_local_w_list_of_str():
126+
d1 = str(tempfile.mkdtemp())
127+
f1 = os.path.join(d1, "f1")
128+
open(f1, "w").write("test1")
129+
fn = open_local([f1, f1])
130+
assert len(fn) == 2
131+
assert isinstance(fn, list)
132+
assert all(isinstance(elem, str) for elem in fn)
133+
134+
135+
def test_open_local_w_path():
136+
d1 = str(tempfile.mkdtemp())
137+
f1 = os.path.join(d1, "f1")
138+
open(f1, "w").write("test1")
139+
p = Path(f1)
140+
fn = open_local(p)
141+
assert isinstance(fn, str)
142+
143+
144+
def test_open_local_w_list_of_path():
145+
d1 = str(tempfile.mkdtemp())
146+
f1 = os.path.join(d1, "f1")
147+
open(f1, "w").write("test1")
148+
p = Path(f1)
149+
fn = open_local([p, p])
150+
assert len(fn) == 2
151+
assert isinstance(fn, list)
152+
assert all(isinstance(elem, str) for elem in fn)
153+
154+
115155
def test_xz_lzma_compressions():
116156
pytest.importorskip("lzma")
117157
# Ensure that both 'xz' and 'lzma' compression names can be parsed

0 commit comments

Comments
 (0)