diff --git a/.gitignore b/.gitignore index 416f213f2c8..b753869be7d 100644 --- a/.gitignore +++ b/.gitignore @@ -20,4 +20,5 @@ examples/**/*.jpg .python-version .coverage *coverage.xml -.ruff_cache \ No newline at end of file +.ruff_cache +*.log \ No newline at end of file diff --git a/log.log b/log.log deleted file mode 100644 index df06bfe8567..00000000000 --- a/log.log +++ /dev/null @@ -1,162 +0,0 @@ -============================= test session starts ============================== -platform darwin -- Python 3.12.10, pytest-8.4.2, pluggy-1.6.0 -- /Users/wenyiguo/keras/venv/bin/python3.12 -cachedir: .pytest_cache -rootdir: /Users/wenyiguo/keras -configfile: pyproject.toml -plugins: cov-7.0.0 -collecting ... collected 1 item - -keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method FAILED - -=================================== FAILURES =================================== -________ TestJaxLayer.test_flax_layer_training_independent_bound_method ________ - -self = -flax_model_class = -flax_model_method = 'forward', init_kwargs = {}, trainable_weights = 8 -trainable_params = 648226, non_trainable_weights = 0, non_trainable_params = 0 - - @parameterized.named_parameters( - { - "testcase_name": "training_independent_bound_method", - "flax_model_class": "FlaxTrainingIndependentModel", - "flax_model_method": "forward", - "init_kwargs": {}, - "trainable_weights": 8, - "trainable_params": 648226, - "non_trainable_weights": 0, - "non_trainable_params": 0, - }, - { - "testcase_name": "training_rng_unbound_method", - "flax_model_class": "FlaxDropoutModel", - "flax_model_method": None, - "init_kwargs": { - "method": "flax_dropout_wrapper", - }, - "trainable_weights": 8, - "trainable_params": 648226, - "non_trainable_weights": 0, - "non_trainable_params": 0, - }, - { - "testcase_name": "training_rng_state_no_method", - "flax_model_class": "FlaxBatchNormModel", - "flax_model_method": None, - "init_kwargs": {}, - "trainable_weights": 13, - "trainable_params": 354258, - "non_trainable_weights": 8, - "non_trainable_params": 536, - }, - { - "testcase_name": "training_rng_unbound_method_dtype_policy", - "flax_model_class": "FlaxDropoutModel", - "flax_model_method": None, - "init_kwargs": { - "method": "flax_dropout_wrapper", - "dtype": DTypePolicy("mixed_float16"), - }, - "trainable_weights": 8, - "trainable_params": 648226, - "non_trainable_weights": 0, - "non_trainable_params": 0, - }, - ) - @pytest.mark.skipif(flax is None, reason="Flax library is not available.") - def test_flax_layer( - self, - flax_model_class, - flax_model_method, - init_kwargs, - trainable_weights, - trainable_params, - non_trainable_weights, - non_trainable_params, - ): - flax_model_class = FLAX_OBJECTS.get(flax_model_class) - if "method" in init_kwargs: - init_kwargs["method"] = FLAX_OBJECTS.get(init_kwargs["method"]) - - def create_wrapper(**kwargs): - params = kwargs.pop("params") if "params" in kwargs else None - state = kwargs.pop("state") if "state" in kwargs else None - if params and state: - variables = {**params, **state} - elif params: - variables = params - elif state: - variables = state - else: - variables = None - kwargs["variables"] = variables - flax_model = flax_model_class() - if flax_model_method: - kwargs["method"] = getattr(flax_model, flax_model_method) - return FlaxLayer(flax_model_class(), **kwargs) - -> self._test_layer( - flax_model_class.__name__, - create_wrapper, - init_kwargs, - trainable_weights, - trainable_params, - non_trainable_weights, - non_trainable_params, - ) - -keras/src/utils/jax_layer_test.py:488: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -keras/src/utils/jax_layer_test.py:231: in _test_layer - outputs1 = layer1(inputs1) - ^^^^^^^^^^^^^^^ -keras/src/utils/traceback_utils.py:113: in error_handler - return fn(*args, **kwargs) - ^^^^^^^^^^^^^^^^^^^ -keras/src/layers/layer.py:866: in __call__ - self._maybe_build(call_spec) -keras/src/layers/layer.py:1477: in _maybe_build - self.build(**shapes_dict) -keras/src/layers/layer.py:231: in build_wrapper - original_build_method(*args, **kwargs) -keras/src/utils/jax_layer.py:510: in build - self._initialize_weights(input_shape) -keras/src/utils/jax_layer.py:497: in _initialize_weights - init_result = self.init_fn(*init_args) - ^^^^^^^^^^^^^^^^^^^^^^^^ -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -rng = {'dropout': None, 'params': None} -inputs = Array([[[[1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.]... [1.], - [1.], - [1.], - [1.], - [1.], - [1.], - [1.]]]], dtype=float32) - - def init_without_training(rng, inputs): - return self._variables_to_params_and_state( -> self.module.init( - rng, - inputs, - method=self.method, - ) - ) -E ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. -E -------------------- -E For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. - -keras/src/utils/jax_layer.py:755: ValueError -=========================== short test summary info ============================ -FAILED keras/src/utils/jax_layer_test.py::TestJaxLayer::test_flax_layer_training_independent_bound_method - ValueError: First argument passed to an init function should be a ``jax.PRNGKey`` or a dictionary mapping strings to ``jax.PRNGKey``. --------------------- -For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these. -============================== 1 failed in 1.72s ===============================