Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,24 @@
import base64
import hashlib
import os
import io
import urllib.parse


def read_file(blobpath: str) -> bytes:
if not blobpath.startswith("http://") and not blobpath.startswith("https://"):
url = urllib.parse.urlparse(blobpath)
if url.scheme is None or url.scheme == "":
with open(blobpath, "rb") as f:
with io.BufferedReader(f) as br:
return br.read()
elif url.scheme in ["http", "https"]:
# avoiding blobfile for public files helps avoid auth issues, like MFA prompts
import requests

resp = requests.get(blobpath)
resp.raise_for_status()
return resp.content
else:
try:
import blobfile
except ImportError as e:
Expand All @@ -16,13 +30,6 @@ def read_file(blobpath: str) -> bytes:
with blobfile.BlobFile(blobpath, "rb") as f:
return f.read()

# avoiding blobfile for public files helps avoid auth issues, like MFA prompts
import requests

resp = requests.get(blobpath)
resp.raise_for_status()
return resp.content


def check_hash(data: bytes, expected_hash: str) -> bool:
actual_hash = hashlib.sha256(data).hexdigest()
Expand Down