Skip to content

Commit b2979c2

Browse files
yanboliangfchollet
authored andcommitted
Network.to_json should handle numpy.ndarray correctly. (#10754)
1 parent da9ce7d commit b2979c2

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

keras/engine/network.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1193,7 +1193,10 @@ def to_json(self, **kwargs):
11931193
def get_json_type(obj):
11941194
# If obj is any numpy type
11951195
if type(obj).__module__ == np.__name__:
1196-
return obj.item()
1196+
if isinstance(obj, np.ndarray):
1197+
return obj.tolist()
1198+
else:
1199+
return obj.item()
11971200

11981201
# If obj is a python 'type'
11991202
if type(obj).__name__ == type.__name__:

tests/keras/engine/test_topology.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from keras import backend as K
1010
from keras.models import model_from_json, model_from_yaml
1111
from keras.utils.test_utils import keras_test
12+
from keras.initializers import Constant
1213

1314

1415
skipif_no_tf_gpu = pytest.mark.skipif(
@@ -797,5 +798,19 @@ def call(self, inputs, **kwargs):
797798
assert K.int_shape(z)[1:] == (16, 16, 3)
798799

799800

801+
@keras_test
802+
def test_constant_initializer_with_numpy():
803+
model = Sequential()
804+
model.add(Dense(2, input_shape=(3,), kernel_initializer=Constant(np.ones((3, 2)))))
805+
model.add(Dense(3))
806+
model.compile(loss='mse', optimizer='sgd', metrics=['acc'])
807+
808+
json_str = model.to_json()
809+
model_from_json(json_str).summary()
810+
811+
yaml_str = model.to_yaml()
812+
model_from_yaml(yaml_str).summary()
813+
814+
800815
if __name__ == '__main__':
801816
pytest.main([__file__])

0 commit comments

Comments
 (0)