@@ -6,6 +6,7 @@
import os
import tempfile
import uuid
+from typing import Optional

import requests

@@ -26,7 +27,12 @@ def read_file(blobpath: str) -> bytes:
    return resp.content


-def read_file_cached(blobpath: str) -> bytes:
+def check_hash(data: bytes, hash: str) -> bool:
+    data_hash = hashlib.sha256(data).hexdigest()
+    return data_hash == hash
+
+
+def read_file_cached(blobpath: str, expected_hash: Optional[str] = None) -> bytes:
    user_specified_cache = True
    if "TIKTOKEN_CACHE_DIR" in os.environ:
        cache_dir = os.environ["TIKTOKEN_CACHE_DIR"]
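Note: the check_hash added above is a plain SHA-256 hex-digest comparison. A minimal standalone sketch of the same check (the sample bytes are illustrative, not from this change):

    import hashlib

    data = b"example"
    expected = hashlib.sha256(data).hexdigest()  # 64-char lowercase hex string
    # Equivalent to check_hash(data, expected): recompute and compare digests.
    assert hashlib.sha256(data).hexdigest() == expected
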
@@ -45,9 +51,20 @@ def read_file_cached(blobpath: str) -> bytes:
    cache_path = os.path.join(cache_dir, cache_key)
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
-            return f.read()
+            data = f.read()
+        if expected_hash and not check_hash(data, expected_hash):
+            raise ValueError(
+                f"Hash mismatch for cached data from {blobpath} (expected {expected_hash}). "
+                f"Please delete the cache file at {cache_path} and try again."
+            )
+        return data

    contents = read_file(blobpath)
+    if expected_hash and not check_hash(contents, expected_hash):
+        raise ValueError(
+            f"Hash mismatch for data downloaded from {blobpath} (expected {expected_hash}). "
+            f"This may indicate a corrupted download. Please try again."
+        )

    try:
        os.makedirs(cache_dir, exist_ok=True)
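This hunk verifies at both trust boundaries: bytes read back from the on-disk cache, and freshly downloaded bytes before they are written into the cache. A condensed sketch of that flow, using a hypothetical zero-argument fetch callable in place of the real read_file:

    import hashlib
    from typing import Callable, Optional

    def fetch_verified(fetch: Callable[[], bytes], expected_hash: Optional[str]) -> bytes:
        # Same policy as the hunk above: verify only when a hash was supplied.
        data = fetch()
        if expected_hash and hashlib.sha256(data).hexdigest() != expected_hash:
            raise ValueError("hash mismatch")
        return data

    # e.g. fetch_verified(lambda: b"payload", hashlib.sha256(b"payload").hexdigest())
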
@@ -64,7 +81,7 @@ def read_file_cached(blobpath: str) -> bytes:


def data_gym_to_mergeable_bpe_ranks(
-    vocab_bpe_file: str, encoder_json_file: str
+    vocab_bpe_file: str, encoder_json_file: str, vocab_bpe_hash: Optional[str] = None, encoder_json_hash: Optional[str] = None
) -> dict[bytes, int]:
    # NB: do not add caching to this function
    rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]
@@ -79,7 +96,7 @@ def data_gym_to_mergeable_bpe_ranks(
    assert len(rank_to_intbyte) == 2**8

    # vocab_bpe contains the merges along with associated ranks
-    vocab_bpe_contents = read_file_cached(vocab_bpe_file).decode()
+    vocab_bpe_contents = read_file_cached(vocab_bpe_file, vocab_bpe_hash).decode()
    bpe_merges = [tuple(merge_str.split()) for merge_str in vocab_bpe_contents.split("\n")[1:-1]]

    def decode_data_gym(value: str) -> bytes:
@@ -96,7 +113,7 @@ def decode_data_gym(value: str) -> bytes:
    # check that the encoder file matches the merges file
    # this sanity check is important since tiktoken assumes that ranks are ordered the same
    # as merge priority
-    encoder_json = json.loads(read_file_cached(encoder_json_file))
+    encoder_json = json.loads(read_file_cached(encoder_json_file, encoder_json_hash))
    encoder_json_loaded = {decode_data_gym(k): v for k, v in encoder_json.items()}
    # drop these two special tokens if present, since they're not mergeable bpe tokens
    encoder_json_loaded.pop(b"<|endoftext|>", None)
@@ -118,9 +135,9 @@ def dump_tiktoken_bpe(bpe_ranks: dict[bytes, int], tiktoken_bpe_file: str) -> None:
        f.write(base64.b64encode(token) + b" " + str(rank).encode() + b"\n")


-def load_tiktoken_bpe(tiktoken_bpe_file: str) -> dict[bytes, int]:
+def load_tiktoken_bpe(tiktoken_bpe_file: str, expected_hash: Optional[str] = None) -> dict[bytes, int]:
    # NB: do not add caching to this function
-    contents = read_file_cached(tiktoken_bpe_file)
+    contents = read_file_cached(tiktoken_bpe_file, expected_hash)
    return {
        base64.b64decode(token): int(rank)
        for token, rank in (line.split() for line in contents.splitlines() if line)
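With these changes, callers of load_tiktoken_bpe and data_gym_to_mergeable_bpe_ranks can pin vocabulary files to a known digest. One way to obtain the value to pass as expected_hash (the file path here is a placeholder, not from this change):

    import hashlib

    # sha256 hexdigest of a known-good local copy of the vocab file
    with open("cl100k_base.tiktoken", "rb") as f:
        digest = hashlib.sha256(f.read()).hexdigest()
    print(digest)  # pass as expected_hash= / vocab_bpe_hash= / encoder_json_hash=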