Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/bioclip/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def create_image_features_for_image(self, image: str | PIL.Image.Image, normaliz

def create_probabilities(self, img_features: torch.Tensor,
txt_features: torch.Tensor) -> dict[str, torch.Tensor]:
logits = (self.model.logit_scale.exp() * img_features @ txt_features)
logits = (self.model.logit_scale * img_features @ txt_features)
return F.softmax(logits, dim=1)

def create_probabilities_for_images(self, images: List[str] | List[PIL.Image.Image],
Expand Down
60 changes: 60 additions & 0 deletions tests/test_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bioclip import BIOCLIP_V2_MODEL_STR, BIOCLIP_V1_MODEL_STR
import os
import torch
import torch.nn.functional as F
import pandas as pd
import PIL.Image

Expand Down Expand Up @@ -377,6 +378,65 @@ def test_format_species_probs_too_few_species(self):
top_probs = classifier.format_species_probs(EXAMPLE_CAT_IMAGE, probs, k=5)


class TestCreateProbabilities(unittest.TestCase):
def _make_classifier_with_logit_scale(self, logit_scale_value):
"""Helper: build a BaseClassifier whose model.logit_scale is a plain tensor."""
mock_model = Mock()
mock_model.logit_scale = torch.tensor(logit_scale_value)
mock_preprocess = Mock(side_effect=lambda img: torch.zeros(3, 224, 224))

with patch('open_clip.create_model_from_pretrained', return_value=(mock_model, mock_preprocess)), \
patch('open_clip.list_pretrained_tags_by_model', return_value=[]), \
patch('torch.compile', side_effect=lambda m, **kw: m):
classifier = BaseClassifier.__new__(BaseClassifier)
classifier.device = 'cpu'
classifier.model = mock_model
classifier.preprocess = mock_preprocess
classifier.recorder = None
return classifier

def test_create_probabilities_uses_logit_scale_directly(self):
"""create_probabilities must multiply by logit_scale without an extra exp()."""
# Use a large scale value (e.g. 100) to make the distinction obvious:
# if exp() were applied the scale would become exp(100) ≈ 2.7e43, causing
# all probability mass to collapse to a single class.
scale = 100.0
classifier = self._make_classifier_with_logit_scale(scale)

# Orthogonal unit vectors: img is similar to class 0 only
img_features = torch.tensor([[1.0, 0.0]])
txt_features = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # shape (2, 2)

probs = classifier.create_probabilities(img_features, txt_features.T)

# Expected with direct logit_scale (no exp):
expected_logits = torch.tensor(scale) * img_features @ txt_features.T
expected_probs = F.softmax(expected_logits, dim=1)

self.assertTrue(torch.allclose(probs, expected_probs),
f"Probabilities {probs} don't match expected {expected_probs}")
# With scale=100 and unit vectors the first class should still dominate
self.assertGreater(probs[0, 0].item(), 0.99)

def test_create_probabilities_exp_would_differ(self):
"""Confirm that applying exp() to logit_scale gives a meaningfully different result."""
scale = 4.6052 # log(100)
classifier = self._make_classifier_with_logit_scale(scale)

img_features = torch.tensor([[1.0, 0.0]])
txt_features = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

probs = classifier.create_probabilities(img_features, txt_features.T)

# What a double-exp implementation would compute
wrong_logits = torch.tensor(scale).exp() * img_features @ txt_features.T
wrong_probs = F.softmax(wrong_logits, dim=1)

# The direct-scale result should be different from the double-exp result
self.assertFalse(torch.allclose(probs, wrong_probs),
"create_probabilities should NOT match double-exp result")


class TestEmbed(unittest.TestCase):
def test_get_image_features(self):
classifier = TreeOfLifeClassifier(device='cpu')
Expand Down
Loading