import math

import torch
import pytest

from vector_quantize_pytorch import LFQ

"""
testing strategy:
exercise two subdivisions - quantizing with a batch mask, and estimating
the per-sample entropy from only a fraction of the codes
(frac_per_sample_entropy < 1)
"""
9+
torch.manual_seed(0)
11+
@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
@pytest.mark.parametrize('mask', (
    torch.tensor([False, False]),
    torch.tensor([True, False]),
    torch.tensor([True, True])
))
def test_masked_lfq(
    frac_per_sample_entropy,
    mask
):
    # you can specify either dim or codebook_size
    # if both are specified, they will be validated against each other

    quantizer = LFQ(
        codebook_size = 65536,       # codebook size, must be a power of 2
        dim = 16,                    # input feature dimension, defaults to log2(codebook_size) if not specified
        entropy_loss_weight = 0.1,   # how much weight to place on the entropy loss
        diversity_gamma = 1.,        # within the entropy loss, how much weight to give to diversity
        frac_per_sample_entropy = frac_per_sample_entropy
    )
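
    # note: the mask is over the batch dimension - samples marked False are
    # presumably excluded from the aux losses (an assumption about the mask
    # semantics; this test only asserts the quantization round trip)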
30+
    image_feats = torch.randn(2, 16, 32, 32)

    ret, loss_breakdown = quantizer(image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)  # you may want to experiment with temperature

    quantized, indices, _ = ret
    assert (quantized == quantizer.indices_to_codes(indices)).all()
37+
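# the test below checks that estimating the entropy from a random fraction of
# the codes (frac_per_sample_entropy = 0.1) is statistically consistent with
# computing it exactly over the full codebook (frac_per_sample_entropy = 1)
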
@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
@pytest.mark.parametrize('iters', (10,))
@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
    image_feats = torch.randn(2, 16, 32, 32)

    full_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,
        dim = 16,
        entropy_loss_weight = 0.1,
        diversity_gamma = 1.,
        frac_per_sample_entropy = 1.   # compute the per-sample entropy exactly, over all codes
    )

    partial_per_sample_entropy_quantizer = LFQ(
        codebook_size = 65536,
        dim = 16,
        entropy_loss_weight = 0.1,
        diversity_gamma = 1.,
        frac_per_sample_entropy = frac_per_sample_entropy   # estimate it from a fraction of the codes
    )
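
    # the two quantizers differ only in frac_per_sample_entropy: the first
    # computes the entropy aux loss over every code, the second estimates it
    # from a random fraction (here 10%) of the codes on each forward pass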
59+
    ret, loss_breakdown = full_per_sample_entropy_quantizer(
        image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)
    true_per_sample_entropy = loss_breakdown.per_sample_entropy

    per_sample_losses = torch.zeros(iters)
    for i in range(iters):
        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
            image_feats, inv_temperature = 100., return_loss_breakdown = True, mask = mask)

        quantized, indices, _ = ret
        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
        per_sample_losses[i] = loss_breakdown.per_sample_entropy
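
    # each partial forward pass estimates the per-sample entropy from a random
    # subset of the codes, so each measurement is a noisy estimate of the exact
    # value; assuming that subsampling is unbiased, the mean over `iters` runs
    # should land within 1.96 standard errors of the exact entropy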
    # 95% confidence interval
    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
        < (1.96 * (per_sample_losses.std() / math.sqrt(iters)))

    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
    print("95% ci half-width: ", 1.96 * (per_sample_losses.std() / math.sqrt(iters)))