diff --git a/examples/gsc/download_preprocessed_data.py b/examples/gsc/download_preprocessed_data.py index 06f9060..cb08161 100644 --- a/examples/gsc/download_preprocessed_data.py +++ b/examples/gsc/download_preprocessed_data.py @@ -67,7 +67,26 @@ def extract_tarball(): with tarfile.open(TARFILEPATH) as tar: # This is slow to count. tot = 42 # len(list(tar.getnames())) - tar.extractall(DATAPATH, + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, DATAPATH, members=tqdm(tardesc="Extracting",total=tot,unit="file",unit_scale="True",leave="False")) members=tqdm(tar, desc="Extracting", total=tot, unit="file", unit_scale=True, leave=False)) diff --git a/examples/gsc/download_raw_data.py b/examples/gsc/download_raw_data.py index 51c5c09..7900695 100644 --- a/examples/gsc/download_raw_data.py +++ b/examples/gsc/download_raw_data.py @@ -71,7 +71,26 @@ def extract_tarball(): with tarfile.open(TARFILEPATH) as tar: # This is slow to count. tot = 64764 # len(list(tar.getnames())) - tar.extractall(EXTRACTPATH, + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, EXTRACTPATH, members=tqdm(tardesc="Extracting",total=tot,unit="file",unit_scale="True",leave="False")) members=tqdm(tar, desc="Extracting", total=tot, unit="file", unit_scale=True, leave=False)) diff --git a/examples/gsc_pretrain/download_data.py b/examples/gsc_pretrain/download_data.py index 51c5c09..7900695 100644 --- a/examples/gsc_pretrain/download_data.py +++ b/examples/gsc_pretrain/download_data.py @@ -71,7 +71,26 @@ def extract_tarball(): with tarfile.open(TARFILEPATH) as tar: # This is slow to count. tot = 64764 # len(list(tar.getnames())) - tar.extractall(EXTRACTPATH, + def is_within_directory(directory, target): + + abs_directory = os.path.abspath(directory) + abs_target = os.path.abspath(target) + + prefix = os.path.commonprefix([abs_directory, abs_target]) + + return prefix == abs_directory + + def safe_extract(tar, path=".", members=None, *, numeric_owner=False): + + for member in tar.getmembers(): + member_path = os.path.join(path, member.name) + if not is_within_directory(path, member_path): + raise Exception("Attempted Path Traversal in Tar File") + + tar.extractall(path, members, numeric_owner=numeric_owner) + + + safe_extract(tar, EXTRACTPATH, members=tqdm(tardesc="Extracting",total=tot,unit="file",unit_scale="True",leave="False")) members=tqdm(tar, desc="Extracting", total=tot, unit="file", unit_scale=True, leave=False))