Author: @8bitmp3 for google/flax docs
This guide provides an overview of how to apply
dropout
using flax.linen.Dropout.
Dropout is a stochastic regularization technique that randomly removes hidden and visible units in a network.
import flax.linen as nn
import jax.numpy as jnp
import jax
import optax

Since dropout is a random operation, it requires a pseudorandom number generator (PRNG) state. Flax uses JAX's (splittable) PRNG keys, which have a number of desirable properties for neural networks. To learn more, refer to the Pseudorandom numbers in JAX tutorial.
Note: Recall that JAX has an explicit way of giving you PRNG keys:
you can fork the main PRNG state (such as
key = jax.random.PRNGKey(seed=0)) into multiple new PRNG keys with
key, subkey = jax.random.split(key). You can refresh your memory in
🔪 JAX - The Sharp Bits 🔪 Randomness and PRNG keys.
Begin by splitting the PRNG key using jax.random.split()
into three keys, including one for Flax Linen Dropout.
root_key = jax.random.PRNGKey(seed=0)
main_key, params_key, dropout_key = jax.random.split(key=root_key, num=3)

Note: In Flax, you provide PRNG streams with names, so that you
can use them later in your flax.linen.Module. For example, you pass the stream 'params' for
initializing parameters, and 'dropout' for applying
flax.linen.Dropout.
To create a model with dropout:
- Subclass flax.linen.Module, and then use flax.linen.Dropout to add a dropout layer. Recall that flax.linen.Module is the base class for all neural network Modules, and all layers and models are subclassed from it.
- In flax.linen.Dropout, the deterministic argument is required to be passed as a keyword argument, either:
  - When constructing the flax.linen.Module; or
  - When calling flax.linen.init() or flax.linen.apply() on a constructed Module. (Refer to flax.linen.module.merge_param for more details.)
- Because deterministic is a boolean (see the short sketch after this list):
  - If it's set to False, the inputs are masked (that is, set to zero) with a probability set by rate, and the remaining inputs are scaled by 1 / (1 - rate), which ensures that the mean of the inputs is preserved.
  - If it's set to True, no mask is applied (dropout is turned off), and the inputs are returned as-is.
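To make the effect of deterministic concrete, here is a minimal sketch using a standalone flax.linen.Dropout module and reusing the imports from the top of this guide (the variable names are only illustrative):

dropout = nn.Dropout(rate=0.5)  # `deterministic` is left unset at construction.
x = jnp.ones((1, 4))

# `deterministic=True`: no mask is applied and the input is returned as-is.
y_eval = dropout.apply({}, x, deterministic=True)

# `deterministic=False`: each input is zeroed with probability `rate` and the
# survivors are scaled by 1 / (1 - rate), so a 'dropout' PRNG key is required.
y_train = dropout.apply(
    {}, x, deterministic=False, rngs={'dropout': jax.random.PRNGKey(0)})

Here y_eval is identical to x, while y_train contains only zeros and twos (the surviving ones scaled by 1 / (1 - rate)).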
A common pattern is to accept a training (or train) argument (a
boolean) in the parent Flax Module, and use it to enable or disable
dropout (as demonstrated in later sections of this guide). In other
machine learning frameworks, like PyTorch or TensorFlow (Keras), this is
specified via a mutable state or a call flag (for example, in
torch.nn.Module.eval
or tf.keras.Model by setting the
training flag).
Note: Flax provides an implicit way of handling PRNG key streams via
the flax.linen.Module.make_rng method. This
allows you to split off a fresh PRNG key inside Flax Modules (or their
sub-Modules) from the PRNG stream. The make_rng method guarantees to
provide a unique key each time you call it. Internally,
flax.linen.Dropout makes use of
flax.linen.Module.make_rng to create a
key for dropout. You can check out the source code.
In short, flax.linen.Module.make_rng guarantees full reproducibility.
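To get a feel for what this looks like inside a Module, here is a rough, hand-rolled sketch of a dropout-like layer that calls make_rng directly (the Module name and the fixed 50% rate are only illustrative; nn.Dropout itself is more general):

class HandRolledDropout(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Pull a fresh, unique key from the 'dropout' PRNG stream.
    key = self.make_rng('dropout')
    # Keep each element with probability 0.5 and rescale the survivors.
    keep = jax.random.bernoulli(key, p=0.5, shape=x.shape)
    return jnp.where(keep, x / 0.5, 0)

x = jnp.ones((1, 4))
y = HandRolledDropout().apply({}, x, rngs={'dropout': jax.random.PRNGKey(0)})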
class MyModel(nn.Module):
  num_neurons: int

  @nn.compact
  def __call__(self, x, training: bool):
    x = nn.Dense(self.num_neurons)(x)
    # Set the dropout layer with a `rate` of 50%.
    # When the `deterministic` flag is `True`, dropout is turned off.
    x = nn.Dropout(rate=0.5, deterministic=not training)(x)
    return x

After creating your model:
- Instantiate the model.
- Then, in the flax.linen.init() call, set training=False.
- Finally, extract the params from the variable dictionary.
Here, the main difference between the code without Flax Dropout and
with Dropout is that the training (or train) argument must be
provided if you need dropout enabled.
my_model = MyModel(num_neurons=3)
x = jnp.empty((3, 4, 4))
# Dropout is disabled with `training=False` (that is, `deterministic=True`).
variables = my_model.init(params_key, x, training=False)
params = variables['params']

When using flax.linen.apply() to run your model:
- Pass training=True to flax.linen.apply().
- Then, to draw PRNG keys during the forward pass (with dropout), provide a PRNG key to seed the 'dropout' stream when you call flax.linen.apply().
# Dropout is enabled with `training=True` (that is, `deterministic=False`).
y = my_model.apply({'params': params}, x, training=True, rngs={'dropout': dropout_key})

Here, the main difference between the code without Flax Dropout and
with Dropout is that the training (or train) and rngs arguments
must be provided if you need dropout enabled.
During evaluation, use the same code with dropout disabled (that is, training=False), which means you do not have to pass an RNG for the 'dropout' stream either.
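A minimal evaluation call, reusing the model, params, and input defined above, would look roughly like this:

# Dropout is disabled with `training=False`, so no 'dropout' RNG is needed.
y_eval = my_model.apply({'params': params}, x, training=False)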
This section explains how to amend your code inside the training step function if you have dropout enabled.
Note: Recall that Flax has a common pattern where you create a
dataclass that represents the whole training state, including parameters
and the optimizer state. Then, you can pass a single parameter,
state: TrainState, to the training step function. Refer to the
flax.training.train_state.TrainState API docs to learn more.
- First, add a key field to a custom flax.training.train_state.TrainState class.
- Then, pass the key value - in this case, the dropout_key - to the train_state.TrainState.create method.
from flax.training import train_state
class TrainState(train_state.TrainState):
  key: jax.random.KeyArray

state = TrainState.create(
  apply_fn=my_model.apply,
  params=params,
  key=dropout_key,
  tx=optax.adam(1e-3)
)

- Next, in the Flax training step function, train_step, generate a new PRNG key from the dropout_key to apply dropout at each step. This can be done with either jax.random.fold_in() or jax.random.split() (see the short sketch after this list). Using jax.random.fold_in() is generally faster: when you use jax.random.split() you split off a PRNG key that can be reused afterwards, whereas jax.random.fold_in() 1) folds in unique data; and 2) can result in longer sequences of PRNG streams.
- Finally, when performing the forward pass, pass the new PRNG key to state.apply_fn() as an extra parameter.
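As a rough sketch of the two options (the variable names are only illustrative; the train_step below uses the fold_in approach):

# Option 1: fold the current step number into the base dropout key.
step_key = jax.random.fold_in(key=state.key, data=state.step)

# Option 2: split the base key, keeping a fresh base key for later reuse.
new_base_key, step_key = jax.random.split(state.key)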
@jax.jit
def train_step(state: TrainState, batch, dropout_key):
  dropout_train_key = jax.random.fold_in(key=dropout_key, data=state.step)
  def loss_fn(params):
    logits = state.apply_fn(
      {'params': params},
      x=batch['image'],
      training=True,
      rngs={'dropout': dropout_train_key}
    )
    loss = optax.softmax_cross_entropy_with_integer_labels(
      logits=logits, labels=batch['label'])
    return loss, logits
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state

For more examples of dropout in Flax models, check out the following:

- A Transformer-based model trained on the WMT Machine Translation dataset. This example uses dropout and attention dropout.
- Applying word dropout to a batch of input IDs in a text classification context. This example uses a custom flax.linen.Dropout layer.
- Defining a prediction token in a decoder of a sequence-to-sequence model.