Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions awscli/botocore/httpchecksum.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,10 @@ def update(self, chunk):
def digest(self):
return self._int_crc32.to_bytes(4, byteorder="big")

@property
def int_crc(self):
return self._int_crc32


class CrtCrc32Checksum(BaseChecksum):
# Note: This class is only used if the CRT is available
Expand All @@ -88,6 +92,10 @@ def update(self, chunk):
def digest(self):
return self._int_crc32.to_bytes(4, byteorder="big")

@property
def int_crc(self):
return self._int_crc32


class CrtCrc32cChecksum(BaseChecksum):
# Note: This class is only used if the CRT is available
Expand All @@ -101,6 +109,10 @@ def update(self, chunk):
def digest(self):
return self._int_crc32c.to_bytes(4, byteorder="big")

@property
def int_crc(self):
return self._int_crc32c


class CrtCrc64NvmeChecksum(BaseChecksum):
# Note: This class is only used if the CRT is available
Expand All @@ -114,6 +126,10 @@ def update(self, chunk):
def digest(self):
return self._int_crc64nvme.to_bytes(8, byteorder="big")

@property
def int_crc(self):
return self._int_crc64nvme


class Sha1Checksum(BaseChecksum):
def __init__(self):
Expand Down
163 changes: 163 additions & 0 deletions awscli/s3transfer/checksums.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import base64
from functools import cached_property

from botocore.httpchecksum import (
CrtCrc32cChecksum,
CrtCrc32Checksum,
CrtCrc64NvmeChecksum,
)


class StreamingChecksumBody:
def __init__(self, stream, starting_index, checksum_validator):
self._stream = stream
self._starting_index = starting_index
self._checksum = _CRC_CHECKSUM_CLS[
checksum_validator.checksum_algorithm
]()
self._checksum_validator = checksum_validator

def read(self, *args, **kwargs):
value = self._stream.read(*args, **kwargs)
self._checksum.update(value)
if not value:
self._checksum_validator.set_part_checksums(
self._starting_index, self._checksum.int_crc
)
return value


class ChecksumValidator:
def __init__(self, stored_checksum, content_length):
self.checksum_algorithm = list(stored_checksum.keys())[0]
self._checksum_value = stored_checksum[self.checksum_algorithm]
self._combine_function = _CRC_CHECKSUM_TO_COMBINE_FUNCTION[
self.checksum_algorithm
]
self._part_checksums = None
self._calculated_checksum = None
self._content_length = content_length

@cached_property
def calculated_checksum(self):
if self._calculated_checksum is None:
self._combine_part_checksums()
return self._calculated_checksum

def set_part_checksums(self, offset, checksum):
if self._part_checksums is None:
self._part_checksums = {}
self._part_checksums[offset] = checksum

def _combine_part_checksums(self):
if self._part_checksums is None:
return
sorted_keys = sorted(self._part_checksums.keys())
combined = self._part_checksums[sorted_keys[0]]
for i, offset in enumerate(sorted_keys[1:]):
part_checksum = self._part_checksums[offset]
if i + 1 == len(sorted_keys) - 1:
next_offset = self._content_length
else:
next_offset = sorted_keys[i + 2]
offset_len = next_offset - offset
combined = self._combine_function(
combined, part_checksum, offset_len
)
self._calculated_checksum = base64.b64encode(
combined.to_bytes(4, byteorder='big')
).decode('ascii')

def validate(self):
if not self._checksum_value:
return
if self.calculated_checksum != self._checksum_value:
raise Exception(
f"stored: {self._checksum_value} != calculated: {self.calculated_checksum}"
)


def combine_crc32(crc1, crc2, len2):
"""
Combine two CRC32 checksums computed with binascii.crc32.

This implementation follows the algorithm used in zlib's crc32_combine.

Args:
crc1: CRC32 checksum of the first data block (from binascii.crc32)
crc2: CRC32 checksum of the second data block (from binascii.crc32)
len2: Length in bytes of the second data block

Returns:
Combined CRC32 checksum as if the two blocks were concatenated
"""

# CRC-32 polynomial in reversed bit order
POLY = 0xEDB88320

def gf2_matrix_times(mat, vec):
"""Multiply matrix by vector over GF(2)"""
result = 0
for i in range(32):
if vec & (1 << i):
result ^= mat[i]
return result & 0xFFFFFFFF

def gf2_matrix_square(square, mat):
"""Square matrix over GF(2)"""
for n in range(32):
square[n] = gf2_matrix_times(mat, mat[n])

# Create initial CRC matrix for 1 bit
odd = [0] * 32
even = [0] * 32

# Build odd matrix (for 1 bit shift)
odd[0] = POLY
for n in range(1, 32):
odd[n] = 1 << (n - 1)

# Square to get even matrix (for 2 bit shift), then keep squaring
gf2_matrix_square(even, odd)
gf2_matrix_square(odd, even)

# Process len2 bytes (8 * len2 bits)
length = len2

# Process chunks of 3 bits at a time (since we have matrices for 4 and 8 bit shifts)
while length != 0:
# Square matrices to advance to next power of 2
gf2_matrix_square(even, odd)
if length & 1:
crc1 = gf2_matrix_times(even, crc1)
length >>= 1

if length == 0:
break

gf2_matrix_square(odd, even)
if length & 1:
crc1 = gf2_matrix_times(odd, crc1)
length >>= 1

# XOR the two CRCs
crc1 ^= crc2

return crc1 & 0xFFFFFFFF


_CRC_CHECKSUM_TO_COMBINE_FUNCTION = {
"ChecksumCRC64NVME": None,
"ChecksumCRC32C": None,
"ChecksumCRC32": combine_crc32,
}


_CRC_CHECKSUM_CLS = {
"ChecksumCRC64NVME": CrtCrc64NvmeChecksum,
"ChecksumCRC32C": CrtCrc32cChecksum,
"ChecksumCRC32": CrtCrc32Checksum,
}


CRC_CHECKSUMS = _CRC_CHECKSUM_TO_COMBINE_FUNCTION.keys()
81 changes: 80 additions & 1 deletion awscli/s3transfer/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
import threading

from botocore.exceptions import ClientError
from s3transfer.checksums import (
CRC_CHECKSUMS,
ChecksumValidator,
StreamingChecksumBody,
)
from s3transfer.compat import seekable
from s3transfer.exceptions import RetriesExceededError, S3DownloadFailedError
from s3transfer.futures import IN_MEMORY_DOWNLOAD_TAG
Expand Down Expand Up @@ -137,6 +142,14 @@ def get_final_io_task(self):
"""
raise NotImplementedError('must implement get_final_io_task()')

def get_validate_checksum_task(self, checksum_validator):
return ValidateChecksumTask(
transfer_coordinator=self._transfer_coordinator,
main_kwargs={
'checksum_validator': checksum_validator,
},
)

def _get_fileobj_from_filename(self, filename):
f = DeferredOpenFile(
filename, mode='wb', open_function=self._osutil.open
Expand Down Expand Up @@ -350,10 +363,12 @@ def _submit(
if (
transfer_future.meta.size is None
or transfer_future.meta.etag is None
or transfer_future.meta.stored_checksum is None
):
response = client.head_object(
Bucket=transfer_future.meta.call_args.bucket,
Key=transfer_future.meta.call_args.key,
ChecksumMode="ENABLED",
**transfer_future.meta.call_args.extra_args,
)
# If a size was not provided figure out the size for the
Expand All @@ -364,6 +379,8 @@ def _submit(
# Provide an etag to ensure a stored object is not modified
# during a multipart download.
transfer_future.meta.provide_object_etag(response.get('ETag'))
# Provide checksum
self._provide_checksum_to_meta(response, transfer_future.meta)

download_output_manager = self._get_download_output_manager_cls(
transfer_future, osutil
Expand Down Expand Up @@ -480,6 +497,20 @@ def _submit_ranged_download_request(
download_output_manager, io_executor
)
)
#
checksum_validator = None
if transfer_future.meta.stored_checksum:
checksum_validator = ChecksumValidator(
transfer_future.meta.stored_checksum,
transfer_future.meta.size,
)
validate_checksum_invoker = CountCallbackInvoker(
self._get_validate_checksum_task(
download_output_manager,
io_executor,
checksum_validator,
)
)
for i in range(num_parts):
# Calculate the range parameter
range_parameter = calculate_range_parameter(
Expand All @@ -494,6 +525,7 @@ def _submit_ranged_download_request(
extra_args['IfMatch'] = transfer_future.meta.etag
extra_args.update(call_args.extra_args)
finalize_download_invoker.increment()
validate_checksum_invoker.increment()
# Submit the ranged downloads
self._transfer_coordinator.submit(
request_executor,
Expand All @@ -511,13 +543,36 @@ def _submit_ranged_download_request(
'download_output_manager': download_output_manager,
'io_chunksize': config.io_chunksize,
'bandwidth_limiter': bandwidth_limiter,
'checksum_validator': checksum_validator,
},
done_callbacks=[finalize_download_invoker.decrement],
done_callbacks=[
validate_checksum_invoker.decrement,
finalize_download_invoker.decrement,
],
),
tag=get_object_tag,
)
validate_checksum_invoker.finalize()
finalize_download_invoker.finalize()

def _get_validate_checksum_task(
self,
download_manager,
io_executor,
checksum_validator,
):
if checksum_validator is None:
task = CompleteDownloadNOOPTask(
transfer_coordinator=self._transfer_coordinator,
)
else:
task = download_manager.get_validate_checksum_task(
checksum_validator,
)
return FunctionContainer(
self._transfer_coordinator.submit, io_executor, task
)

def _get_final_io_task_submission_callback(
self, download_manager, io_executor
):
Expand All @@ -536,6 +591,18 @@ def _calculate_range_param(self, part_size, part_index, num_parts):
range_param = f'bytes={start_range}-{end_range}'
return range_param

def _provide_checksum_to_meta(self, response, transfer_meta):
checksum_type = response.get("ChecksumType")
if not checksum_type or checksum_type != "FULL_OBJECT":
transfer_meta.provide_stored_checksum({})
for crc_checksum in CRC_CHECKSUMS:
if checksum_value := response.get(crc_checksum):
transfer_meta.provide_stored_checksum(
{crc_checksum: checksum_value}
)
return
transfer_meta.provide_stored_checksum({})


class GetObjectTask(Task):
def _main(
Expand All @@ -551,6 +618,7 @@ def _main(
io_chunksize,
start_index=0,
bandwidth_limiter=None,
checksum_validator=None,
):
"""Downloads an object and places content into io queue

Expand Down Expand Up @@ -580,6 +648,12 @@ def _main(
streaming_body = StreamReaderProgress(
response['Body'], callbacks
)
if checksum_validator:
streaming_body = StreamingChecksumBody(
streaming_body,
current_index,
checksum_validator,
)
if bandwidth_limiter:
streaming_body = (
bandwidth_limiter.get_bandwith_limited_stream(
Expand Down Expand Up @@ -831,3 +905,8 @@ def request_writes(self, offset, data):
del self._pending_offsets[next_write_offset]
self._next_offset += len(next_write)
return writes


class ValidateChecksumTask(Task):
def _main(self, checksum_validator):
checksum_validator.validate()
8 changes: 8 additions & 0 deletions awscli/s3transfer/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def __init__(self, call_args=None, transfer_id=None):
self._size = None
self._user_context = {}
self._etag = None
self._stored_checksum = None

@property
def call_args(self):
Expand All @@ -154,6 +155,13 @@ def etag(self):
"""The etag of the stored object for validating multipart downloads"""
return self._etag

@property
def stored_checksum(self):
return self._stored_checksum

def provide_stored_checksum(self, checksum):
self._stored_checksum = checksum

def provide_transfer_size(self, size):
"""A method to provide the size of a transfer request

Expand Down
Loading