@@ -277,15 +277,16 @@ def _posemb_sincos_2d(
277
277
width : int ,
278
278
temperature : float = 10_000.0 ,
279
279
dtype : jnp .dtype = jnp .float32 ,
280
+ precision : str = "default" ,
280
281
):
281
282
"""Follows the MoCo v3 logic."""
282
283
y , x = jnp .mgrid [:h , :w ] # pylint: disable=unpacking-non-sequence
283
284
284
285
assert width % 4 == 0 , "Width must be mult of 4 for sincos posemb"
285
286
omega = jnp .arange (width // 4 ) / (width // 4 - 1 )
286
287
omega = 1.0 / (temperature ** omega )
287
- y = jnp .einsum ("m,d->md" , y .flatten (), omega )
288
- x = jnp .einsum ("m,d->md" , x .flatten (), omega )
288
+ y = jnp .einsum ("m,d->md" , y .flatten (), omega , precision = jax . lax . Precision ( precision ) )
289
+ x = jnp .einsum ("m,d->md" , x .flatten (), omega , precision = jax . lax . Precision ( precision ) )
289
290
pe = jnp .concatenate ([jnp .sin (x ), jnp .cos (x ), jnp .sin (y ), jnp .cos (y )], axis = 1 )
290
291
return jnp .asarray (pe , dtype )[None , :, :]
291
292
@@ -297,18 +298,22 @@ class MlpBlockViT(nn.Module):
297
298
dtype_mm : str
298
299
mlp_dim : int | None = None # Defaults to 4x input dim
299
300
dropout : float = 0.0
301
+ precision : str = "default"
300
302
301
303
@nn .compact
302
304
def __call__ (self , x : jax .Array , deterministic : bool = True ) -> jax .Array :
303
305
"""Applies Transformer MlpBlock module."""
304
306
inits = {"kernel_init" : nn .initializers .xavier_uniform (), "bias_init" : nn .initializers .normal (stddev = 1e-6 )}
305
307
306
308
d = x .shape [- 1 ]
307
- x = nn .Dense (features = self .mlp_dim or 4 * d , dtype = self .dtype_mm , ** inits )(x )
309
+ x = nn .Dense (features = self .mlp_dim or 4 * d , precision = jax .lax .Precision (self .precision ), dtype = self .dtype_mm , ** inits )(
310
+ x
311
+ )
308
312
x = nn .gelu (x )
309
313
x = nn .Dropout (rate = self .dropout )(x , deterministic )
310
314
x = nn .Dense (
311
315
features = d ,
316
+ precision = jax .lax .Precision (self .precision ),
312
317
dtype = self .dtype_mm ,
313
318
** inits ,
314
319
)(x )
@@ -323,6 +328,7 @@ class Encoder1DBlock(nn.Module):
323
328
mlp_dim : int | None = None # Defaults to 4x input dim
324
329
num_heads : int = 12
325
330
dropout : float = 0.0
331
+ precision : str = "default"
326
332
327
333
@nn .compact
328
334
def __call__ (self , x : jax .Array , deterministic : bool = True ) -> jax .Array :
@@ -331,6 +337,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
331
337
y = nn .MultiHeadDotProductAttention (
332
338
num_heads = self .num_heads ,
333
339
kernel_init = nn .initializers .xavier_uniform (),
340
+ precision = jax .lax .Precision (self .precision ),
334
341
deterministic = deterministic ,
335
342
dtype = self .dtype_mm ,
336
343
)(y , y )
@@ -343,6 +350,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
343
350
mlp_dim = self .mlp_dim ,
344
351
dropout = self .dropout ,
345
352
dtype_mm = self .dtype_mm ,
353
+ precision = self .precision ,
346
354
)(y , deterministic )
347
355
y = nn .Dropout (rate = self .dropout )(y , deterministic )
348
356
x = x + y
@@ -358,7 +366,8 @@ class Encoder(nn.Module):
358
366
mlp_dim : int | None = None # Defaults to 4x input dim
359
367
num_heads : int = 12
360
368
dropout : float = 0.0
361
- scan : bool = False
369
+ scan : bool = False
370
+ precision : str = "default"
362
371
363
372
@nn .compact
364
373
def __call__ (self , x : jax .Array , deterministic : bool = True ) -> jax .Array :
@@ -383,6 +392,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
383
392
mlp_dim = self .mlp_dim ,
384
393
num_heads = self .num_heads ,
385
394
dropout = self .dropout ,
395
+ precision = self .precision ,
386
396
)(
387
397
x , deterministic
388
398
)
@@ -396,6 +406,7 @@ def __call__(self, x: jax.Array, deterministic: bool = True) -> jax.Array:
396
406
mlp_dim = self .mlp_dim ,
397
407
num_heads = self .num_heads ,
398
408
dropout = self .dropout ,
409
+ precision = self .precision ,
399
410
)
400
411
x = block_cur (x , deterministic )
401
412
x : jax .Array = nn .LayerNorm (name = "encoder_norm" )(x )
@@ -409,6 +420,7 @@ class Einsum(nn.Module):
409
420
weight_name : str = "w"
410
421
initializer : nn .initializers .Initializer = nn .initializers .normal ()
411
422
dtype : jnp .dtype | None = None
423
+ precision : str = "default"
412
424
413
425
@nn .compact
414
426
def __call__ (self , eqn : str , x : jax .Array ) -> jax .Array :
@@ -418,7 +430,7 @@ def __call__(self, eqn: str, x: jax.Array) -> jax.Array:
418
430
self .shape ,
419
431
self .dtype if self .dtype is not None else None ,
420
432
)
421
- return jnp .einsum (eqn , x , w )
433
+ return jnp .einsum (eqn , x , w , precision = jax . lax . Precision ( self . precision ) )
422
434
423
435
424
436
class VisionEmbedder (nn .Module ):
@@ -430,8 +442,10 @@ class VisionEmbedder(nn.Module):
430
442
431
443
def setup (self ):
432
444
if self .vision_proj_dim :
433
- self .mm_soft_embedding_norm = rms_norm (self .vision_proj_dim )
434
- self .mm_input_projection = Einsum ((self .vision_proj_dim , self .config .emb_dim ))
445
+ self .mm_soft_embedding_norm = rms_norm (self .vision_proj_dim , dtype = self .config .dtype_mm )
446
+ self .mm_input_projection = Einsum (
447
+ (self .vision_proj_dim , self .config .emb_dim ), dtype = self .config .dtype_mm , precision = self .config .matmul_precision
448
+ )
435
449
436
450
def encode_vision (self , x : jax .Array ) -> jax .Array :
437
451
x = self .mm_soft_embedding_norm (x )
@@ -494,6 +508,7 @@ def _get_posemb(
494
508
width : int ,
495
509
name : str ,
496
510
dtype : jnp .dtype = jnp .float32 ,
511
+ precision : str = "default" ,
497
512
):
498
513
"""Returns the position embedding."""
499
514
if typ == "learn" :
@@ -505,7 +520,7 @@ def _get_posemb(
505
520
dtype ,
506
521
)
507
522
elif typ == "sincos2d" :
508
- return _posemb_sincos_2d (* seqshape , width = width , dtype = dtype )
523
+ return _posemb_sincos_2d (* seqshape , width = width , dtype = dtype , precision = precision )
509
524
else :
510
525
raise ValueError (f"Unknown posemb type: { typ } " )
511
526
@@ -524,7 +539,15 @@ def __call__(self, inputs, deterministic, train=False):
524
539
b , n , h , w , c = inputs .shape
525
540
x = jnp .reshape (inputs , [b * n , h , w , c ])
526
541
# Gemma3 uses conv2d with stride 14 and kernel size 14 to extract patches.
527
- x = nn .Conv (features = 1152 , kernel_size = (14 , 14 ), strides = 14 , padding = "VALID" , name = "embedding" )(x )
542
+ x = nn .Conv (
543
+ features = 1152 ,
544
+ kernel_size = (14 , 14 ),
545
+ strides = 14 ,
546
+ padding = "VALID" ,
547
+ name = "embedding" ,
548
+ dtype = cfg .dtype_mm ,
549
+ precision = jax .lax .Precision (cfg .matmul_precision ),
550
+ )(x )
528
551
bn , h , w , c = x .shape
529
552
x = jnp .reshape (x , [bn , h * w , c ])
530
553
@@ -535,6 +558,7 @@ def __call__(self, inputs, deterministic, train=False):
535
558
width = c ,
536
559
name = "pos_embedding" ,
537
560
dtype = x .dtype ,
561
+ precision = cfg .matmul_precision ,
538
562
)
539
563
540
564
x = nn .Dropout (rate = self .dropout )(x , not train )
@@ -549,6 +573,7 @@ def __call__(self, inputs, deterministic, train=False):
549
573
remat_policy = cfg .remat_policy_for_vit ,
550
574
dtype_mm = cfg .dtype_mm ,
551
575
name = "Transformer" ,
576
+ precision = cfg .matmul_precision ,
552
577
)(x , deterministic = deterministic )
553
578
554
579
# Gemma3 use a vision exit layer to downsample the soft tokens to a required output length.
0 commit comments