Skip to content

Commit 553ea54

Browse files
committed
fix: remove unused code
1 parent 4e759ec commit 553ea54

File tree

1 file changed

+0
-16
lines changed

1 file changed

+0
-16
lines changed

MaxText/layers/embeddings.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -686,22 +686,6 @@ def __call__(self, inputs: Array, position: None | Array = None) -> Array:
686686
output = output.astype(self.fprop_dtype)
687687
return output
688688

689-
690-
def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int = _MAX_WAVELENGTH):
691-
"""Initializes the PositionalEmbedding module and returns it as a Linen module.
692-
693-
Args:
694-
embedding_dims: The dimension of the embeddings.
695-
max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
696-
"""
697-
return nnx_wrappers.to_linen(
698-
PositionalEmbedding,
699-
embedding_dims=embedding_dims,
700-
max_wavelength=max_wavelength,
701-
metadata_fn=variable_to_logically_partitioned,
702-
)
703-
704-
705689
@dataclasses.dataclass(repr=False)
706690
class PositionalEmbedding(nnx.Module):
707691
"""A layer that adds sinusoidal positional embeddings to the input.

0 commit comments

Comments
 (0)