diff --git a/src/bioclip/predict.py b/src/bioclip/predict.py index 9f5e4fb..b91e4dd 100644 --- a/src/bioclip/predict.py +++ b/src/bioclip/predict.py @@ -239,7 +239,7 @@ def create_image_features_for_image(self, image: str | PIL.Image.Image, normaliz def create_probabilities(self, img_features: torch.Tensor, txt_features: torch.Tensor) -> dict[str, torch.Tensor]: - logits = (self.model.logit_scale.exp() * img_features @ txt_features) + logits = (self.model.logit_scale * img_features @ txt_features) return F.softmax(logits, dim=1) def create_probabilities_for_images(self, images: List[str] | List[PIL.Image.Image], diff --git a/tests/test_predict.py b/tests/test_predict.py index 769998c..7d730dc 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -8,6 +8,7 @@ from bioclip import BIOCLIP_V2_MODEL_STR, BIOCLIP_V1_MODEL_STR import os import torch +import torch.nn.functional as F import pandas as pd import PIL.Image @@ -377,6 +378,65 @@ def test_format_species_probs_too_few_species(self): top_probs = classifier.format_species_probs(EXAMPLE_CAT_IMAGE, probs, k=5) +class TestCreateProbabilities(unittest.TestCase): + def _make_classifier_with_logit_scale(self, logit_scale_value): + """Helper: build a BaseClassifier whose model.logit_scale is a plain tensor.""" + mock_model = Mock() + mock_model.logit_scale = torch.tensor(logit_scale_value) + mock_preprocess = Mock(side_effect=lambda img: torch.zeros(3, 224, 224)) + + with patch('open_clip.create_model_from_pretrained', return_value=(mock_model, mock_preprocess)), \ + patch('open_clip.list_pretrained_tags_by_model', return_value=[]), \ + patch('torch.compile', side_effect=lambda m, **kw: m): + classifier = BaseClassifier.__new__(BaseClassifier) + classifier.device = 'cpu' + classifier.model = mock_model + classifier.preprocess = mock_preprocess + classifier.recorder = None + return classifier + + def test_create_probabilities_uses_logit_scale_directly(self): + """create_probabilities must multiply by logit_scale without an extra exp().""" + # Use a large scale value (e.g. 100) to make the distinction obvious: + # if exp() were applied the scale would become exp(100) ≈ 2.7e43, causing + # all probability mass to collapse to a single class. + scale = 100.0 + classifier = self._make_classifier_with_logit_scale(scale) + + # Orthogonal unit vectors: img is similar to class 0 only + img_features = torch.tensor([[1.0, 0.0]]) + txt_features = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) # shape (2, 2) + + probs = classifier.create_probabilities(img_features, txt_features.T) + + # Expected with direct logit_scale (no exp): + expected_logits = torch.tensor(scale) * img_features @ txt_features.T + expected_probs = F.softmax(expected_logits, dim=1) + + self.assertTrue(torch.allclose(probs, expected_probs), + f"Probabilities {probs} don't match expected {expected_probs}") + # With scale=100 and unit vectors the first class should still dominate + self.assertGreater(probs[0, 0].item(), 0.99) + + def test_create_probabilities_exp_would_differ(self): + """Confirm that applying exp() to logit_scale gives a meaningfully different result.""" + scale = 4.6052 # log(100) + classifier = self._make_classifier_with_logit_scale(scale) + + img_features = torch.tensor([[1.0, 0.0]]) + txt_features = torch.tensor([[1.0, 0.0], [0.0, 1.0]]) + + probs = classifier.create_probabilities(img_features, txt_features.T) + + # What a double-exp implementation would compute + wrong_logits = torch.tensor(scale).exp() * img_features @ txt_features.T + wrong_probs = F.softmax(wrong_logits, dim=1) + + # The direct-scale result should be different from the double-exp result + self.assertFalse(torch.allclose(probs, wrong_probs), + "create_probabilities should NOT match double-exp result") + + class TestEmbed(unittest.TestCase): def test_get_image_features(self): classifier = TreeOfLifeClassifier(device='cpu')