26 changes: 26 additions & 0 deletions Dockerfile.hpu
@@ -0,0 +1,26 @@
# Use the official Gaudi Docker image with PyTorch
FROM vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest

# Set environment variables for Habana
ENV HABANA_VISIBLE_DEVICES=all
ENV OMPI_MCA_btl_vader_single_copy_mechanism=none
ENV PT_HPU_LAZY_ACC_PAR_MODE=0
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=1

# Set timezone to UTC and install essential packages
ENV DEBIAN_FRONTEND="noninteractive" TZ=Etc/UTC
RUN apt-get update && apt-get install -y \
tzdata \
python3-pip \
&& rm -rf /var/lib/apt/lists/*

COPY . /workspace/clip
WORKDIR /workspace/clip

# Copy HPU requirements
COPY requirements_hpu.txt /workspace/requirements_hpu.txt

# Install Python packages
RUN pip install --upgrade pip \
&& pip install -r requirements_hpu.txt \
&& pip install -e .
41 changes: 38 additions & 3 deletions README.md
@@ -29,7 +29,9 @@ import torch
import clip
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
from clip.utils import get_device_initial

device = get_device_initial() # "HPU" if using Intel® Gaudi® HPU, "cuda" if using CUDA GPU, "cpu" otherwise
model, preprocess = clip.load("ViT-B/32", device=device)

image = preprocess(Image.open("CLIP.png")).unsqueeze(0).to(device)
@@ -94,8 +96,10 @@ import clip
import torch
from torchvision.datasets import CIFAR100

+from clip.utils import get_device_initial
+
# Load the model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = get_device_initial()
model, preprocess = clip.load('ViT-B/32', device)

# Download the dataset
@@ -153,8 +157,10 @@ from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from tqdm import tqdm

+from clip.utils import get_device_initial
+
# Load the model
-device = "cuda" if torch.cuda.is_available() else "cpu"
+device = get_device_initial()
model, preprocess = clip.load('ViT-B/32', device)

# Load the dataset
@@ -193,6 +199,35 @@ print(f"Accuracy = {accuracy:.3f}")
Note that the `C` value should be determined via a hyperparameter sweep using a validation split.
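
For example, a minimal sweep might look like the following (a sketch, not part of this change; it assumes `train_features`/`train_labels` plus a held-out `val_features`/`val_labels` split produced the same way as the features above):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Try log-spaced C values: 0.001, 0.01, ..., 1000
best_C, best_acc = None, 0.0
for C in np.logspace(-3, 3, 7):
    classifier = LogisticRegression(random_state=0, C=C, max_iter=1000)
    classifier.fit(train_features, train_labels)
    acc = (classifier.predict(val_features) == val_labels).mean() * 100.0
    if acc > best_acc:
        best_C, best_acc = C, acc

print(f"Best C = {best_C:g} (validation accuracy = {best_acc:.3f})")
```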


## Intel® Gaudi® HPU Usage

### Build the Docker Image
To run this project on an Intel® Gaudi® HPU, start by building a Docker image with the appropriate environment.

```bash
docker build -t clip_hpu:latest -f Dockerfile.hpu .
```

In the `Dockerfile.hpu`, we use the `vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest` base image. Ensure that the version matches your Gaudi software setup.
See the [PyTorch Docker Images for the Intel® Gaudi® Accelerator](https://developer.habana.ai/catalog/pytorch-container/) for more information.
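
You can optionally pull the base image first to confirm it is reachable from your machine (the tag below mirrors the `FROM` line in `Dockerfile.hpu`; substitute the release that matches your Gaudi driver stack):

```bash
docker pull vault.habana.ai/gaudi-docker/1.18.0/ubuntu22.04/habanalabs/pytorch-installer-2.4.0:latest
```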

### Run the Container

```bash
docker run -it --runtime=habana clip_hpu:latest
```
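
To work against a local checkout instead of the copy baked into the image, you can mount it over the container's working directory (a sketch; `/workspace/clip` is the `WORKDIR` set in `Dockerfile.hpu`):

```bash
docker run -it --runtime=habana -v "$PWD":/workspace/clip clip_hpu:latest
```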

### Python Usage with Intel® Gaudi® HPU

You do not need to change any code to leverage an Intel® Gaudi® HPU: `get_device_initial()` automatically detects the available hardware and returns the appropriate device name.
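
For example, the zero-shot snippet from the usage section above runs unchanged inside the container (the commented line shows the optional `preferred_device` argument from `clip/utils.py`):

```python
import clip
from clip.utils import get_device_initial

device = get_device_initial()          # "hpu" inside the Gaudi container
# device = get_device_initial("cpu")   # optionally prefer a specific device
model, preprocess = clip.load("ViT-B/32", device=device)
```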

### Run the Tests

```bash
pytest
```
This runs the test suite, including `test_hpu_support`, which verifies that HPU and CPU outputs agree for each model.
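
To run only the HPU test added in this change:

```bash
pytest tests/test_consistency.py -k test_hpu_support
```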

## See Also

* [OpenCLIP](https://github.com/mlfoundations/open_clip): includes larger and independently trained CLIP models up to ViT-G/14
68 changes: 55 additions & 13 deletions clip/clip.py
@@ -12,9 +12,11 @@

from .model import build_model
from .simple_tokenizer import SimpleTokenizer as _Tokenizer
from .utils import get_device_initial

try:
from torchvision.transforms import InterpolationMode

BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
@@ -51,13 +53,24 @@ def _download(url: str, root: str):
raise RuntimeError(f"{download_target} exists and is not a regular file")

if os.path.isfile(download_target):
-        if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+        if (
+            hashlib.sha256(open(download_target, "rb").read()).hexdigest()
+            == expected_sha256
+        ):
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
warnings.warn(
f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
)

with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
-        with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
+        with tqdm(
+            total=int(source.info().get("Content-Length")),
+            ncols=80,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as loop:
while True:
buffer = source.read(8192)
if not buffer:
@@ -91,7 +104,12 @@ def available_models() -> List[str]:
return list(_MODELS.keys())


-def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
+def load(
+    name: str,
+    device: Union[str, torch.device] = get_device_initial(),
+    jit: bool = False,
+    download_root: str = None,
+):
"""Load a CLIP model

Parameters
@@ -100,7 +118,7 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict

device : Union[str, torch.device]
-        The device to put the loaded model
+        The device on which to load the model; defaults to the device returned by `get_device_initial()`

jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
@@ -123,10 +141,12 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")

-    with open(model_path, 'rb') as opened_file:
+    with open(model_path, "rb") as opened_file:
try:
# loading JIT archive
-            model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
+            model = torch.jit.load(
+                opened_file, map_location=device if jit else "cpu"
+            ).eval()
state_dict = None
except RuntimeError:
# loading saved state dict
@@ -136,13 +156,25 @@ def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_a
state_dict = torch.load(opened_file, map_location="cpu")

if not jit:
-        model = build_model(state_dict or model.state_dict()).to(device)
+        model = build_model(state_dict or model.state_dict())

if str(device) == "hpu":
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
else:
model = model.to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)

# patch the device names
-    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+    device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device("cpu" if device == "hpu" else device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]

def _node_get(node: torch._C.Node, key: str):
@@ -171,9 +203,11 @@ def patch_device(module):
patch_device(model.encode_image)
patch_device(model.encode_text)

-    # patch dtype to float32 on CPU
-    if str(device) == "cpu":
-        float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+    # patch dtype to float32 on CPU and HPU
+    if str(device) in ["cpu", "hpu"]:
+        float_holder = torch.jit.trace(
+            lambda: torch.ones([]).float(), example_inputs=[]
+        )
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()

@@ -199,10 +233,18 @@ def patch_float(module):

model.float()

if str(device) == "hpu":
if torch.hpu.is_available():
from habana_frameworks.torch.hpu import wrap_in_hpu_graph

model = wrap_in_hpu_graph(model)
model = model.eval().to(torch.device(device))
return model, _transform(model.input_resolution.item())


def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> Union[torch.IntTensor, torch.LongTensor]:
def tokenize(
texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False
) -> Union[torch.IntTensor, torch.LongTensor]:
"""
Returns the tokenized representation of given input string(s)

30 changes: 30 additions & 0 deletions clip/utils.py
@@ -0,0 +1,30 @@
import importlib.util

import torch


def get_device_initial(preferred_device=None):
"""
Determine the appropriate device to use (cuda, hpu, or cpu).
Args:
preferred_device (str): User-preferred device ('cuda', 'hpu', or 'cpu').

Returns:
str: Device string ('cuda', 'hpu', or 'cpu').
"""
# Check for HPU support
if importlib.util.find_spec("habana_frameworks") is not None:
from habana_frameworks.torch.utils.library_loader import load_habana_module

load_habana_module()
if torch.hpu.is_available():
if preferred_device == "hpu" or preferred_device is None:
return "hpu"

# Check for CUDA (GPU support)
if torch.cuda.is_available():
if preferred_device == "cuda" or preferred_device is None:
return "cuda"

# Default to CPU
return "cpu"
3 changes: 3 additions & 0 deletions requirements_hpu.txt
@@ -0,0 +1,3 @@
-r requirements.txt
optimum-habana==1.14.1
pytest
22 changes: 21 additions & 1 deletion tests/test_consistency.py
@@ -2,11 +2,12 @@
import pytest
import torch
from PIL import Image
import habana_frameworks.torch  # makes the "hpu" device available to PyTorch

import clip


-@pytest.mark.parametrize('model_name', clip.available_models())
+@pytest.mark.parametrize("model_name", clip.available_models())
def test_consistency(model_name):
device = "cpu"
jit_model, transform = clip.load(model_name, device=device, jit=True)
@@ -23,3 +24,22 @@ def test_consistency(model_name):
py_probs = logits_per_image.softmax(dim=-1).cpu().numpy()

assert np.allclose(jit_probs, py_probs, atol=0.01, rtol=0.1)


@pytest.mark.parametrize("model_name", clip.available_models())
def test_hpu_support(model_name):
devices = ["hpu", "cpu"]
all_probs = []
for device in devices:
print(f"=== Testing {model_name} on {device} ===")
model, transform = clip.load(model_name, device=device, jit=False)

image = transform(Image.open("CLIP.png")).unsqueeze(0).to(device)
text = clip.tokenize(["a diagram", "a dog", "a cat"]).to(device)

with torch.no_grad():
logits_per_image, _ = model(image, text)
probs = logits_per_image.softmax(dim=-1).cpu().numpy()
all_probs.append(probs)

assert np.allclose(all_probs[0], all_probs[1], atol=0.01, rtol=0.1)