@@ -30,7 +30,7 @@
 import optax
 import orbax.checkpoint as ocp
 from mlrx.training import core
-from mlrx.training import jax as jax_lib
+from mlrx.training import jax_trainer
 from mlrx.training import partitioning
 import tensorflow as tf

@@ -42,7 +42,7 @@ def __call__(self, inputs: jax.Array) -> jax.Array:
     return nn.Dense(1, kernel_init=nn.initializers.constant(-1.0))(inputs)


-class _JaxTask(jax_lib.JaxTask):
+class _JaxTask(jax_trainer.JaxTask):

   def create_datasets(
       self,
@@ -90,7 +90,7 @@ def eval_step(
     return {"loss": clu_metrics.Average.from_model_output(loss)}


-class _KerasJaxTask(jax_lib.JaxTask):
+class _KerasJaxTask(jax_trainer.JaxTask):

   def create_datasets(self) -> tf.data.Dataset:
     def _map_fn(x: int):
@@ -106,7 +106,7 @@ def _map_fn(x: int):

   def create_state(
       self, batch: jt.PyTree, rng: jax.Array
-  ) -> jax_lib.KerasState:
+  ) -> jax_trainer.KerasState:
     x, _ = batch

     model = keras.Sequential(
@@ -122,11 +122,11 @@ def create_state(
     model.build(x.shape)

     optimizer = optax.adagrad(0.1)
-    return jax_lib.KerasState.create(model=model, tx=optimizer)
+    return jax_trainer.KerasState.create(model=model, tx=optimizer)

   def train_step(
-      self, batch: jt.PyTree, state: jax_lib.KerasState, rng: jax.Array
-  ) -> tuple[jax_lib.KerasState, Mapping[str, clu_metrics.Metric]]:
+      self, batch: jt.PyTree, state: jax_trainer.KerasState, rng: jax.Array
+  ) -> tuple[jax_trainer.KerasState, Mapping[str, clu_metrics.Metric]]:
     x, y = batch

     def _loss_fn(tvars):
@@ -140,7 +140,7 @@ def _loss_fn(tvars):
     return state, {"loss": clu_metrics.Average.from_model_output(loss)}

   def eval_step(
-      self, batch: jt.PyTree, state: jax_lib.KerasState
+      self, batch: jt.PyTree, state: jax_trainer.KerasState
   ) -> Mapping[str, clu_metrics.Metric]:
     x, y = batch
     y_pred, _ = state.model.stateless_call(state.tvars, state.ntvars, x)
@@ -208,13 +208,13 @@ def setUp(self):
   )
   def test_jax_trainer(
       self,
-      task_cls: type[jax_lib.JaxTask],
+      task_cls: type[jax_trainer.JaxTask],
       mode: str,
       expected_keys: Sequence[str],
   ):
     model_dir = self.create_tempdir().full_path
     task = task_cls()
-    trainer = jax_lib.JaxTrainer(
+    trainer = jax_trainer.JaxTrainer(
         partitioner=partitioning.DataParallelPartitioner(data_axis="batch"),
         train_steps=12,
         steps_per_eval=3,
@@ -258,7 +258,7 @@ class State:
         ),
     )
     state = State(step=10, opt_state=tx.init({"a": jnp.ones((10, 10))}))
-    metrics = jax_lib._state_metrics(state)
+    metrics = jax_trainer._state_metrics(state)
     self.assertIn("optimizer/learning_rate", metrics)
     self.assertEqual(metrics["optimizer/learning_rate"].compute(), 0.1)
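
For downstream code, the visible effect of this change is the import path and the qualified names: from mlrx.training import jax as jax_lib becomes from mlrx.training import jax_trainer, and every jax_lib.* reference becomes jax_trainer.*. A minimal sketch of an updated call site follows, using only the constructor arguments that appear in this diff; any other JaxTrainer arguments passed by the test are elided, so this is illustrative rather than a complete invocation.

from mlrx.training import jax_trainer
from mlrx.training import partitioning

# Construct the trainer under the new module name, mirroring the arguments
# visible in the updated test: a data-parallel partitioner over the "batch"
# axis, 12 training steps, and steps_per_eval=3. Remaining arguments from
# the test are omitted in this sketch.
trainer = jax_trainer.JaxTrainer(
    partitioner=partitioning.DataParallelPartitioner(data_axis="batch"),
    train_steps=12,
    steps_per_eval=3,
)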