I want to freeze parts of my network for training, so to do this, I modified the ppo.train function to accept an optimizer object.
Traceback (most recent call last):
File "/Users/scripts/train.py", line 181, in <module>
main()
File "/Users/scripts/train.py", line 79, in main
tx = optimizer.init(init_params)
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/combine.py", line 214, in init_fn
inner_states = {
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/combine.py", line 217, in <dictcomp>
mask_compatible_extra_args=mask_compatible_extra_args).init(params)
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/wrappers.py", line 545, in init_fn
masked_params = mask_pytree(params, mask_tree)
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/optax/_src/wrappers.py", line 509, in mask_pytree
return jax.tree_util.tree_map(
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/jax/_src/tree_util.py", line 311, in tree_map
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
File "/Users/.pyenv/versions/3.10.12/lib/python3.10/site-packages/jax/_src/tree_util.py", line 311, in <listcomp>
all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
ValueError: Expected dict, got PPONetworkParams(policy={'params': {'ResidualPolicy_0': {'Dense_0': {'kernel': Array([[-0.081, 0.018, -0.062, ..., 0.081, 0.015, 0.01 ],
[ 0.018, 0.075, 0.039, ..., 0.035, 0.072, -0.018],
[-0.046, 0.078, 0.027, ..., 0.038, 0.04 , -0.08 ],
...,
[ 0.012, 0.029, -0.009, ..., 0.01 , -0.017, 0.022],
[ 0.007, 0.034, 0.002, ..., -0.087, 0.059, 0.08 ],
[ 0.084, 0.024, -0.017, ..., -0.061, -0.087, 0.007]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_1': {'kernel': Array([[ 0.079, 0.092, 0.13 , ..., 0.01 , -0.024, -0.124],
[ 0.037, -0.038, 0.072, ..., 0.13 , 0.051, -0.032],
[-0.087, -0.007, -0.026, ..., 0.051, 0.071, -0.025],
...,
[ 0.033, -0.073, -0.038, ..., 0.154, 0.025, 0.062],
[-0.182, -0.127, 0.011, ..., -0.003, -0.08 , 0.029],
[-0.115, -0.125, 0.094, ..., -0.041, -0.002, -0.139]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_2': {'kernel': Array([[-0.008, 0.126, -0.005, ..., -0.137, -0.017, 0.149],
[ 0.07 , -0.039, -0.172, ..., -0.043, -0.088, 0.138],
[ 0.146, -0.02 , 0.142, ..., 0.044, 0.013, 0.092],
...,
[-0.132, -0.006, 0.021, ..., 0.028, -0.068, -0.086],
[-0.065, 0.124, 0.024, ..., -0.073, -0.087, 0.068],
[ 0.154, -0.026, -0.121, ..., 0.008, -0.054, 0.035]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0.], dtype=float32)}}}}, value={'params': {'hidden_0': {'kernel': Array([[ 0.061, -0.025, 0.026, ..., -0.067, 0.02 , -0.025],
[ 0.027, -0.046, -0.015, ..., -0.003, 0.04 , 0.057],
[-0.04 , 0.069, -0.004, ..., -0.06 , 0.067, 0.046],
...,
[ 0.032, 0.055, 0.014, ..., -0.068, 0.02 , -0.005],
[-0.016, 0.048, -0.002, ..., -0.07 , -0.062, -0.077],
[-0.004, -0.064, 0.032, ..., 0.004, 0.001, 0.002]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)}, 'hidden_1': {'kernel': Array([[ 0.003, 0.021, -0.014, ..., -0.06 , 0.055, 0.096],
[-0.028, -0.01 , -0.013, ..., 0.025, 0.016, 0.019],
[ 0.02 , 0.026, 0.029, ..., 0.078, 0.099, -0.078],
...,
[ 0.071, -0.071, 0.03 , ..., 0.104, -0.084, 0.103],
[ 0.011, 0.095, -0.06 , ..., -0. , 0.071, -0.007],
[-0.039, 0.024, -0.053, ..., 0.051, -0.02 , 0.008]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)}, 'hidden_2': {'kernel': Array([[-0.047, 0.041, -0.093, ..., 0.084, 0.103, 0.034],
[-0.018, -0.038, 0.016, ..., 0.066, 0.052, 0.105],
[ 0.002, 0.044, 0.059, ..., -0.008, -0.037, -0.102],
...,
[-0.045, -0.085, 0.017, ..., 0.024, -0.07 , -0.034],
[ 0.1 , -0.042, 0.026, ..., 0.051, -0.103, 0.056],
[ 0.002, 0.095, -0.073, ..., -0.049, 0.018, -0.077]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)}, 'hidden_3': {'kernel': Array([[ 0.08 , -0.093, -0.103, ..., -0.04 , 0.063, -0.052],
[-0.028, -0.087, -0.056, ..., -0.017, -0.072, -0.107],
[ 0.039, -0.024, -0.021, ..., -0.052, 0.028, -0.012],
...,
[ 0.036, 0.087, -0.033, ..., 0.073, -0.05 , 0.044],
[-0.102, 0.069, 0.002, ..., 0.075, 0.095, -0.098],
[-0.066, -0.019, 0.053, ..., 0.043, 0.025, -0.092]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)}, 'hidden_4': {'kernel': Array([[ 0.08 , -0.102, -0.031, ..., 0.009, -0.012, 0.017],
[ 0.038, 0.081, 0.033, ..., 0.059, 0.056, 0.018],
[ 0.06 , -0.024, 0.034, ..., -0.046, 0.044, 0.021],
...,
[-0.048, -0.052, -0.099, ..., -0.05 , 0.016, 0.003],
[-0.101, 0.042, -0.023, ..., -0.027, -0.016, 0.061],
[ 0.014, -0.084, -0.028, ..., 0.079, 0.049, 0.104]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0.], dtype=float32)}, 'hidden_5': {'kernel': Array([[ 0.091],
[-0.073],
[ 0.103],
[ 0.062],
[-0.028],
[ 0.078],
[-0.066],
[ 0.047],
[ 0.066],
[-0.095],
[ 0.004],
[ 0.017],
[ 0.038],
[-0.031],
[ 0.108],
[-0.048],
[ 0.065],
[ 0.099],
[ 0.031],
[ 0.023],
[ 0.014],
[ 0.025],
[ 0.086],
[ 0.087],
[ 0.052],
[ 0.029],
[-0.079],
[ 0.087],
[-0.016],
[-0.02 ],
[-0.1 ],
[ 0.086],
[ 0.018],
[-0.021],
[-0.025],
[ 0.068],
[-0.068],
[ 0.071],
[ 0.013],
[ 0.073],
[ 0.076],
[ 0.098],
[ 0.035],
[ 0.074],
[-0.052],
[ 0.044],
[-0.073],
[-0.006],
[-0.033],
[-0.091],
[-0.037],
[-0.006],
[ 0.003],
[ 0.082],
[ 0.001],
[-0.001],
[ 0.083],
[ 0.003],
[ 0.088],
[ 0.023],
[-0.031],
[ 0.027],
[-0.016],
[ 0.046],
[ 0.038],
[-0.062],
[-0.011],
[-0.037],
[-0.07 ],
[-0.065],
[ 0.033],
[ 0.1 ],
[-0.084],
[ 0.055],
[ 0.024],
[-0.083],
[ 0.031],
[ 0.048],
[-0.1 ],
[ 0.098],
[-0.037],
[ 0.084],
[-0.004],
[ 0.004],
[-0.029],
[ 0.071],
[ 0.031],
[ 0.069],
[ 0.079],
[-0.064],
[ 0.068],
[-0.055],
[-0.085],
[ 0.102],
[-0.032],
[-0.08 ],
[-0.094],
[-0.098],
[-0.097],
[ 0.041],
[-0.015],
[ 0.032],
[ 0.048],
[-0.073],
[ 0.071],
[-0.098],
[ 0.072],
[ 0.051],
[-0.031],
[ 0.078],
[-0.001],
[-0.052],
[-0.011],
[-0.003],
[-0.003],
[-0.01 ],
[ 0.013],
[-0.058],
[-0.076],
[ 0.107],
[-0.014],
[ 0.102],
[ 0.054],
[-0.047],
[-0.095],
[-0.041],
[-0.049],
[ 0.043],
[-0.092],
[ 0.016],
[ 0.026],
[ 0.098],
[-0.101],
[ 0.065],
[-0.027],
[ 0.085],
[ 0.093],
[-0.105],
[ 0.079],
[-0.036],
[ 0.089],
[-0.008],
[-0.05 ],
[-0.072],
[ 0.094],
[-0.001],
[-0.042],
[ 0.049],
[-0.065],
[-0.011],
[-0.083],
[-0.008],
[ 0.011],
[-0.032],
[-0.052],
[ 0.052],
[ 0.026],
[-0.069],
[-0.01 ],
[ 0.059],
[-0.079],
[-0.071],
[-0.019],
[-0.041],
[-0.052],
[-0.053],
[-0.072],
[-0.083],
[ 0.017],
[ 0.071],
[ 0.067],
[ 0.002],
[-0.042],
[-0.085],
[-0.006],
[-0.016],
[-0.086],
[-0.101],
[ 0.06 ],
[-0.067],
[-0.052],
[ 0.004],
[ 0.076],
[ 0.075],
[-0.106],
[-0.044],
[-0.066],
[ 0.086],
[ 0.05 ],
[-0.083],
[ 0.105],
[ 0.08 ],
[ 0.103],
[ 0.072],
[ 0.024],
[ 0. ],
[ 0.065],
[ 0.025],
[ 0.047],
[-0.083],
[ 0.014],
[ 0.059],
[ 0.072],
[-0.058],
[ 0.091],
[-0.033],
[-0.011],
[-0.097],
[-0.077],
[ 0.049],
[-0.058],
[-0.053],
[-0.081],
[-0.032],
[ 0.06 ],
[-0.093],
[-0.087],
[-0.064],
[-0.008],
[-0.052],
[-0.058],
[ 0.072],
[-0.031],
[ 0.07 ],
[-0.068],
[-0.102],
[ 0.045],
[-0.104],
[ 0.097],
[-0.054],
[-0.035],
[ 0.021],
[ 0.015],
[ 0.064],
[-0.006],
[-0.008],
[ 0.068],
[-0.014],
[ 0.065],
[ 0.034],
[ 0.002],
[ 0.014],
[-0.103],
[ 0.082],
[-0.088],
[-0.029],
[-0.006],
[ 0.083],
[ 0.063],
[-0.049],
[ 0.099],
[-0.012],
[-0.018],
[ 0.077],
[ 0.054],
[-0.095]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}).
I want to freeze parts of my network for training, so to do this, I modified the
ppo.trainfunction to accept an optimizer object.I then define the optimizer as
However, when the optimizer calls init as
optimizer.init(init_params)when defining the TrainingState, it throws the following error:This does not happen when I define the optimizer as
op = optax.adam(3.0e-4). Since both variants areoptax.GradientTransformationExtraArgs, can someone please explain what is happening and how can I resolve this?Here's the full stack trace: