Skip to content

Commit 0e8ff70

Browse files
committed
513 refactor recursive ls and add test
1 parent fcc1c8e commit 0e8ff70

File tree

3 files changed

+70
-42
lines changed

3 files changed

+70
-42
lines changed

databricks_cli/dbfs/api.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,18 @@ class DbfsApi(object):
9191
def __init__(self, api_client):
9292
self.client = DbfsService(api_client)
9393

94-
def list_files(self, dbfs_path, headers=None):
95-
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
94+
def _recursive_list(self, dbfs_path, headers=None):
    """
    Collect the listing entries of every file below ``dbfs_path``.

    One ``client.list`` call is issued per directory. The files of a
    directory come first in the result, followed by the files found in each
    of its subdirectories; directory entries themselves are left out.

    :param dbfs_path: directory to start from; its ``absolute_path`` is what
        gets sent to the service.
    :param headers: optional headers, forwarded unchanged to every
        ``client.list`` call.
    :return: ``{'files': [...]}`` holding the raw JSON entries of the
        non-directory files. This mirrors the shape of a single
        ``client.list`` response so ``list_files`` can parse a recursive and
        a flat result the same way.
    """
    response = self.client.list(dbfs_path.absolute_path, headers=headers)
    # An empty directory answers without a 'files' key.
    entries = response.get('files', [])

    collected = [entry for entry in entries if not entry['is_dir']]
    for entry in entries:
        if entry['is_dir']:
            # FileInfo.from_json turns the entry's path into the path object
            # that the next level needs for its own ``absolute_path``.
            sub_path = FileInfo.from_json(entry).dbfs_path
            collected.extend(self._recursive_list(sub_path, headers=headers)['files'])
    return {'files': collected}
100+
101+
def list_files(self, dbfs_path, headers=None, is_recursive=False):
102+
if is_recursive:
103+
list_response = self._recursive_list(dbfs_path, headers)
104+
else:
105+
list_response = self.client.list(dbfs_path.absolute_path, headers=headers)
96106
if 'files' in list_response:
97107
return [FileInfo.from_json(f) for f in list_response['files']]
98108
else:

databricks_cli/dbfs/cli.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@
3838
@click.option('-l', is_flag=True, default=False,
3939
help="""Displays full information including size, file type
4040
and modification time since Epoch in milliseconds.""")
41-
@click.option('--recursive', is_flag=True, default=False,
41+
@click.option('--recursive', '-r', is_flag=True, default=False,
4242
help='Displays all subdirectories and files.')
4343
@click.argument('dbfs_path', nargs=-1, type=DbfsPathClickType())
4444
@debug_option
4545
@profile_option
4646
@eat_exceptions
4747
@provide_api_client
48-
def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
48+
def ls_cli(api_client, l, absolute, recursive, dbfs_path): # NOQA
4949
"""
5050
List files in DBFS.
5151
"""
@@ -55,20 +55,13 @@ def ls_cli(api_client, l, absolute, dbfs_path): # NOQA
5555
dbfs_path = dbfs_path[0]
5656
else:
5757
error_and_quit('ls can take a maximum of one path.')
58-
59-
def echo_path(files):
60-
table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
61-
tablefmt='plain')
62-
click.echo(table)
63-
64-
def recursive_echo(this_dbfs_path):
65-
files = DbfsApi(api_client).list_files(this_dbfs_path)
66-
echo_path(files)
67-
for f in files:
68-
if f.is_dir:
69-
recursive_echo(this_dbfs_path.join(f.basename))
70-
71-
recursive_echo(dbfs_path) if recursive else echo_path(dbfs_path)
58+
59+
# Keep the listing: the rows below are built from it. Previously the result
# was discarded and ``files`` was read without ever being bound.
files = DbfsApi(api_client).list_files(dbfs_path, is_recursive=recursive)
# A recursive listing forces absolute paths in the output.
absolute = absolute or recursive

table = tabulate([f.to_row(is_long_form=l, is_absolute=absolute) for f in files],
                 tablefmt='plain')
click.echo(table)
7265

7366

7467
@click.command(context_settings=CONTEXT_SETTINGS)

tests/dbfs/test_api.py

Lines changed: 49 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,29 @@
3434
from databricks_cli.dbfs.dbfs_path import DbfsPath
3535
from databricks_cli.dbfs.exceptions import LocalFileExistsException
3636

37-
TEST_DBFS_PATH = DbfsPath('dbfs:/test')
37+
TEST_DBFS_PATH1 = DbfsPath('dbfs:/test')
38+
TEST_DBFS_PATH2 = DbfsPath('dbfs:/dir/test')
3839
DUMMY_TIME = 1613158406000
39-
TEST_FILE_JSON = {
40+
TEST_FILE_JSON1 = {
4041
'path': '/test',
4142
'is_dir': False,
4243
'file_size': 1,
4344
'modification_time': DUMMY_TIME
4445
}
45-
TEST_FILE_INFO = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
46+
TEST_FILE_JSON2 = {
47+
'path': '/dir/test',
48+
'is_dir': False,
49+
'file_size': 1,
50+
'modification_time': DUMMY_TIME
51+
}
52+
TEST_DIR_JSON = {
53+
'path': '/dir',
54+
'is_dir': True,
55+
'file_size': 0,
56+
'modification_time': DUMMY_TIME
57+
}
58+
TEST_FILE_INFO0 = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
59+
TEST_FILE_INFO1 = api.FileInfo(TEST_DBFS_PATH2, False, 1, DUMMY_TIME)
4660

4761

4862
def get_resource_does_not_exist_exception():
@@ -60,22 +74,22 @@ def get_partial_delete_exception(message="[...] operation has deleted 10 files [
6074

6175
class TestFileInfo(object):
6276
def test_to_row_not_long_form_not_absolute(self):
63-
file_info = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
77+
file_info = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
6478
row = file_info.to_row(is_long_form=False, is_absolute=False)
6579
assert len(row) == 1
66-
assert TEST_DBFS_PATH.basename == row[0]
80+
assert TEST_DBFS_PATH1.basename == row[0]
6781

6882
def test_to_row_long_form_not_absolute(self):
69-
file_info = api.FileInfo(TEST_DBFS_PATH, False, 1, DUMMY_TIME)
83+
file_info = api.FileInfo(TEST_DBFS_PATH1, False, 1, DUMMY_TIME)
7084
row = file_info.to_row(is_long_form=True, is_absolute=False)
7185
assert len(row) == 4
7286
assert row[0] == 'file'
7387
assert row[1] == 1
74-
assert TEST_DBFS_PATH.basename == row[2]
88+
assert TEST_DBFS_PATH1.basename == row[2]
7589

7690
def test_from_json(self):
77-
file_info = api.FileInfo.from_json(TEST_FILE_JSON)
78-
assert file_info.dbfs_path == TEST_DBFS_PATH
91+
file_info = api.FileInfo.from_json(TEST_FILE_JSON1)
92+
assert file_info.dbfs_path == TEST_DBFS_PATH1
7993
assert not file_info.is_dir
8094
assert file_info.file_size == 1
8195

@@ -89,41 +103,52 @@ def dbfs_api():
89103

90104

91105
class TestDbfsApi(object):
106+
def test_list_files_recursive(self, dbfs_api):
    """
    A recursive listing returns the files of the root and of its
    subdirectory, and omits the directory entry itself.
    """
    # One response per directory level. A single constant return_value that
    # contains a directory would make the recursion list that same
    # directory forever.
    root_listing = {
        'files': [TEST_FILE_JSON1, TEST_DIR_JSON]
    }
    dir_listing = {
        'files': [TEST_FILE_JSON2]
    }
    dbfs_api.client.list.side_effect = [root_listing, dir_listing]
    # list_files reads ``absolute_path`` from its argument, so pass a
    # DbfsPath rather than a plain string.
    files = dbfs_api.list_files(DbfsPath('dbfs:/'), is_recursive=True)

    assert len(files) == 2
    assert TEST_FILE_INFO0 == files[0]
    assert TEST_FILE_INFO1 == files[1]
    # Root plus the one subdirectory.
    assert dbfs_api.client.list.call_count == 2
116+
92117
def test_list_files_exists(self, dbfs_api):
93118
json = {
94-
'files': [TEST_FILE_JSON]
119+
'files': [TEST_FILE_JSON1]
95120
}
96121
dbfs_api.client.list.return_value = json
97-
files = dbfs_api.list_files(TEST_DBFS_PATH)
122+
files = dbfs_api.list_files(TEST_DBFS_PATH1, is_recursive=True)
98123

99124
assert len(files) == 1
100-
assert TEST_FILE_INFO == files[0]
125+
assert TEST_FILE_INFO0 == files[0]
101126

102127
def test_list_files_does_not_exist(self, dbfs_api):
103128
json = {}
104129
dbfs_api.client.list.return_value = json
105-
files = dbfs_api.list_files(TEST_DBFS_PATH)
130+
files = dbfs_api.list_files(TEST_DBFS_PATH1)
106131

107132
assert len(files) == 0
108133

109134
def test_file_exists_true(self, dbfs_api):
110-
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
111-
assert dbfs_api.file_exists(TEST_DBFS_PATH)
135+
dbfs_api.client.get_status.return_value = TEST_FILE_JSON1
136+
assert dbfs_api.file_exists(TEST_DBFS_PATH1)
112137

113138
def test_file_exists_false(self, dbfs_api):
114139
exception = get_resource_does_not_exist_exception()
115140
dbfs_api.client.get_status = mock.Mock(side_effect=exception)
116-
assert not dbfs_api.file_exists(TEST_DBFS_PATH)
141+
assert not dbfs_api.file_exists(TEST_DBFS_PATH1)
117142

118143
def test_get_status(self, dbfs_api):
119-
dbfs_api.client.get_status.return_value = TEST_FILE_JSON
120-
assert dbfs_api.get_status(TEST_DBFS_PATH) == TEST_FILE_INFO
144+
dbfs_api.client.get_status.return_value = TEST_FILE_JSON1
145+
assert dbfs_api.get_status(TEST_DBFS_PATH1) == TEST_FILE_INFO0
121146

122147
def test_get_status_fail(self, dbfs_api):
123148
exception = get_resource_does_not_exist_exception()
124149
dbfs_api.client.get_status = mock.Mock(side_effect=exception)
125150
with pytest.raises(exception.__class__):
126-
dbfs_api.get_status(TEST_DBFS_PATH)
151+
dbfs_api.get_status(TEST_DBFS_PATH1)
127152

128153
def test_put_file(self, dbfs_api, tmpdir):
129154
test_file_path = os.path.join(tmpdir.strpath, 'test')
@@ -133,7 +158,7 @@ def test_put_file(self, dbfs_api, tmpdir):
133158
api_mock = dbfs_api.client
134159
test_handle = 0
135160
api_mock.create.return_value = {'handle': test_handle}
136-
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
161+
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH1, True)
137162

138163
# Should not call add-block since file is < 2GB
139164
assert api_mock.add_block.call_count == 0
@@ -148,7 +173,7 @@ def test_put_large_file(self, dbfs_api, tmpdir):
148173
dbfs_api.MULTIPART_UPLOAD_LIMIT = 2
149174
test_handle = 0
150175
api_mock.create.return_value = {'handle': test_handle}
151-
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH, True)
176+
dbfs_api.put_file(test_file_path, TEST_DBFS_PATH1, True)
152177
assert api_mock.add_block.call_count == 1
153178
assert test_handle == api_mock.add_block.call_args[0][0]
154179
assert b64encode(b'test').decode() == api_mock.add_block.call_args[0][1]
@@ -160,18 +185,18 @@ def test_get_file_check_overwrite(self, dbfs_api, tmpdir):
160185
with open(test_file_path, 'w') as f:
161186
f.write('test')
162187
with pytest.raises(LocalFileExistsException):
163-
dbfs_api.get_file(TEST_DBFS_PATH, test_file_path, False)
188+
dbfs_api.get_file(TEST_DBFS_PATH1, test_file_path, False)
164189

165190
def test_get_file(self, dbfs_api, tmpdir):
166191
api_mock = dbfs_api.client
167-
api_mock.get_status.return_value = TEST_FILE_JSON
192+
api_mock.get_status.return_value = TEST_FILE_JSON1
168193
api_mock.read.return_value = {
169194
'bytes_read': 1,
170195
'data': b64encode(b'x'),
171196
}
172197

173198
test_file_path = os.path.join(tmpdir.strpath, 'test')
174-
dbfs_api.get_file(TEST_DBFS_PATH, test_file_path, True)
199+
dbfs_api.get_file(TEST_DBFS_PATH1, test_file_path, True)
175200

176201
with open(test_file_path, 'r') as f:
177202
assert f.read() == 'x'

0 commit comments

Comments
 (0)