@@ -108,6 +108,7 @@ def __init__(
108
108
sparse : bool = False ,
109
109
device = None ,
110
110
dtype = None ,
111
+ ** kwargs ,
111
112
) -> None :
112
113
factory_kwargs = {"device" : device , "dtype" : dtype }
113
114
# Have to call Module init explicitly in order not to use the Embedding init
@@ -116,6 +117,7 @@ def __init__(
116
117
self .num_embeddings = num_embeddings
117
118
self .embedding_dim = embedding_dim
118
119
self .vsa = vsa
120
+ self .vsa_kwargs = kwargs
119
121
120
122
if padding_idx is not None :
121
123
if padding_idx > 0 :
@@ -135,7 +137,7 @@ def __init__(
135
137
self .sparse = sparse
136
138
137
139
embeddings = functional .empty (
138
- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
140
+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
139
141
)
140
142
# Have to provide requires grad at the creation of the parameters to
141
143
# prevent errors when instantiating a non-float embedding
@@ -148,7 +150,11 @@ def reset_parameters(self) -> None:
148
150
149
151
with torch .no_grad ():
150
152
embeddings = functional .empty (
151
- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
153
+ self .num_embeddings ,
154
+ self .embedding_dim ,
155
+ self .vsa ,
156
+ ** factory_kwargs ,
157
+ ** self .vsa_kwargs ,
152
158
)
153
159
self .weight .copy_ (embeddings )
154
160
@@ -214,6 +220,7 @@ def __init__(
214
220
sparse : bool = False ,
215
221
device = None ,
216
222
dtype = None ,
223
+ ** kwargs ,
217
224
) -> None :
218
225
factory_kwargs = {"device" : device , "dtype" : dtype }
219
226
# Have to call Module init explicitly in order not to use the Embedding init
@@ -222,6 +229,7 @@ def __init__(
222
229
self .num_embeddings = num_embeddings
223
230
self .embedding_dim = embedding_dim
224
231
self .vsa = vsa
232
+ self .vsa_kwargs = kwargs
225
233
226
234
if padding_idx is not None :
227
235
if padding_idx > 0 :
@@ -241,7 +249,7 @@ def __init__(
241
249
self .sparse = sparse
242
250
243
251
embeddings = functional .identity (
244
- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
252
+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
245
253
)
246
254
# Have to provide requires grad at the creation of the parameters to
247
255
# prevent errors when instantiating a non-float embedding
@@ -254,7 +262,11 @@ def reset_parameters(self) -> None:
254
262
255
263
with torch .no_grad ():
256
264
embeddings = functional .identity (
257
- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
265
+ self .num_embeddings ,
266
+ self .embedding_dim ,
267
+ self .vsa ,
268
+ ** factory_kwargs ,
269
+ ** self .vsa_kwargs ,
258
270
)
259
271
self .weight .copy_ (embeddings )
260
272
@@ -266,7 +278,7 @@ def _fill_padding_idx_with_empty(self) -> None:
266
278
if self .padding_idx is not None :
267
279
with torch .no_grad ():
268
280
empty = functional .empty (
269
- 1 , self .embedding_dim , self .vsa , ** factory_kwargs
281
+ 1 , self .embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
270
282
)
271
283
self .weight [self .padding_idx ].copy_ (empty .squeeze (0 ))
272
284
@@ -332,6 +344,7 @@ def __init__(
332
344
sparse : bool = False ,
333
345
device = None ,
334
346
dtype = None ,
347
+ ** kwargs ,
335
348
) -> None :
336
349
factory_kwargs = {"device" : device , "dtype" : dtype }
337
350
# Have to call Module init explicitly in order not to use the Embedding init
@@ -340,6 +353,7 @@ def __init__(
340
353
self .num_embeddings = num_embeddings
341
354
self .embedding_dim = embedding_dim
342
355
self .vsa = vsa
356
+ self .vsa_kwargs = kwargs
343
357
344
358
if padding_idx is not None :
345
359
if padding_idx > 0 :
@@ -359,7 +373,7 @@ def __init__(
359
373
self .sparse = sparse
360
374
361
375
embeddings = functional .random (
362
- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
376
+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
363
377
)
364
378
# Have to provide requires grad at the creation of the parameters to
365
379
# prevent errors when instantiating a non-float embedding
@@ -372,7 +386,11 @@ def reset_parameters(self) -> None:
372
386
373
387
with torch .no_grad ():
374
388
embeddings = functional .random (
375
- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
389
+ self .num_embeddings ,
390
+ self .embedding_dim ,
391
+ self .vsa ,
392
+ ** factory_kwargs ,
393
+ ** self .vsa_kwargs ,
376
394
)
377
395
self .weight .copy_ (embeddings )
378
396
@@ -384,7 +402,7 @@ def _fill_padding_idx_with_empty(self) -> None:
384
402
if self .padding_idx is not None :
385
403
with torch .no_grad ():
386
404
empty = functional .empty (
387
- 1 , self .embedding_dim , self .vsa , ** factory_kwargs
405
+ 1 , self .embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
388
406
)
389
407
self .weight [self .padding_idx ].copy_ (empty .squeeze (0 ))
390
408
@@ -469,6 +487,7 @@ def __init__(
469
487
sparse : bool = False ,
470
488
device = None ,
471
489
dtype = None ,
490
+ ** kwargs ,
472
491
) -> None :
473
492
factory_kwargs = {"device" : device , "dtype" : dtype }
474
493
# Have to call Module init explicitly in order not to use the Embedding init
@@ -477,6 +496,7 @@ def __init__(
477
496
self .num_embeddings = num_embeddings
478
497
self .embedding_dim = embedding_dim
479
498
self .vsa = vsa
499
+ self .vsa_kwargs = kwargs
480
500
self .low = low
481
501
self .high = high
482
502
self .randomness = randomness
@@ -493,6 +513,7 @@ def __init__(
493
513
self .vsa ,
494
514
randomness = randomness ,
495
515
** factory_kwargs ,
516
+ ** self .vsa_kwargs ,
496
517
)
497
518
# Have to provide requires grad at the creation of the parameters to
498
519
# prevent errors when instantiating a non-float embedding
@@ -508,6 +529,7 @@ def reset_parameters(self) -> None:
508
529
self .vsa ,
509
530
randomness = self .randomness ,
510
531
** factory_kwargs ,
532
+ ** self .vsa_kwargs ,
511
533
)
512
534
self .weight .copy_ (embeddings )
513
535
@@ -592,6 +614,7 @@ def __init__(
592
614
sparse : bool = False ,
593
615
device = None ,
594
616
dtype = None ,
617
+ ** kwargs ,
595
618
) -> None :
596
619
factory_kwargs = {"device" : device , "dtype" : dtype }
597
620
# Have to call Module init explicitly in order not to use the Embedding init
@@ -600,6 +623,7 @@ def __init__(
600
623
self .num_embeddings = num_embeddings
601
624
self .embedding_dim = embedding_dim
602
625
self .vsa = vsa
626
+ self .vsa_kwargs = kwargs
603
627
self .low = low
604
628
self .high = high
605
629
@@ -610,7 +634,7 @@ def __init__(
610
634
self .sparse = sparse
611
635
612
636
embeddings = functional .thermometer (
613
- num_embeddings , embedding_dim , self .vsa , ** factory_kwargs
637
+ num_embeddings , embedding_dim , self .vsa , ** factory_kwargs , ** self . vsa_kwargs
614
638
)
615
639
# Have to provide requires grad at the creation of the parameters to
616
640
# prevent errors when instantiating a non-float embedding
@@ -621,7 +645,11 @@ def reset_parameters(self) -> None:
621
645
622
646
with torch .no_grad ():
623
647
embeddings = functional .thermometer (
624
- self .num_embeddings , self .embedding_dim , self .vsa , ** factory_kwargs
648
+ self .num_embeddings ,
649
+ self .embedding_dim ,
650
+ self .vsa ,
651
+ ** factory_kwargs ,
652
+ ** self .vsa_kwargs ,
625
653
)
626
654
self .weight .copy_ (embeddings )
627
655
@@ -704,6 +732,7 @@ def __init__(
704
732
sparse : bool = False ,
705
733
device = None ,
706
734
dtype = None ,
735
+ ** kwargs ,
707
736
) -> None :
708
737
factory_kwargs = {"device" : device , "dtype" : dtype }
709
738
# Have to call Module init explicitly in order not to use the Embedding init
@@ -712,6 +741,7 @@ def __init__(
712
741
self .num_embeddings = num_embeddings
713
742
self .embedding_dim = embedding_dim
714
743
self .vsa = vsa
744
+ self .vsa_kwargs = kwargs
715
745
self .phase = phase
716
746
self .period = period
717
747
self .randomness = randomness
@@ -728,6 +758,7 @@ def __init__(
728
758
self .vsa ,
729
759
randomness = randomness ,
730
760
** factory_kwargs ,
761
+ ** self .vsa_kwargs ,
731
762
)
732
763
# Have to provide requires grad at the creation of the parameters to
733
764
# prevent errors when instantiating a non-float embedding
@@ -743,6 +774,7 @@ def reset_parameters(self) -> None:
743
774
self .vsa ,
744
775
randomness = self .randomness ,
745
776
** factory_kwargs ,
777
+ ** self .vsa_kwargs ,
746
778
)
747
779
self .weight .copy_ (embeddings )
748
780
@@ -945,6 +977,7 @@ def __init__(
945
977
device = None ,
946
978
dtype = None ,
947
979
requires_grad : bool = False ,
980
+ ** kwargs ,
948
981
):
949
982
factory_kwargs = {
950
983
"device" : device ,
@@ -954,10 +987,16 @@ def __init__(
954
987
super (Density , self ).__init__ ()
955
988
956
989
# A set of random vectors used as unique IDs for features of the dataset.
957
- self .key = Random (in_features , out_features , vsa , ** factory_kwargs )
990
+ self .key = Random (in_features , out_features , vsa , ** factory_kwargs , ** kwargs )
958
991
# Thermometer encoding used for transforming input data.
959
992
self .density_encoding = Thermometer (
960
- out_features + 1 , out_features , vsa , low = low , high = high , ** factory_kwargs
993
+ out_features + 1 ,
994
+ out_features ,
995
+ vsa ,
996
+ low = low ,
997
+ high = high ,
998
+ ** factory_kwargs ,
999
+ ** kwargs ,
961
1000
)
962
1001
963
1002
def reset_parameters (self ) -> None :
0 commit comments