@@ -10,6 +10,11 @@ def test_config(self):
1010 self .assertEqual (acc_obj .name , "accuracy" )
1111 self .assertEqual (len (acc_obj .variables ), 2 )
1212 self .assertEqual (acc_obj ._dtype , "float32" )
13+
14+ # Test get_config
15+ acc_obj_config = acc_obj .get_config ()
16+ self .assertEqual (acc_obj_config ["name" ], "accuracy" )
17+ self .assertEqual (acc_obj_config ["dtype" ], "float32" )
1318 # TODO: Check save and restore config
1419
1520 def test_unweighted (self ):
@@ -38,6 +43,11 @@ def test_config(self):
3843 self .assertEqual (bin_acc_obj .name , "binary_accuracy" )
3944 self .assertEqual (len (bin_acc_obj .variables ), 2 )
4045 self .assertEqual (bin_acc_obj ._dtype , "float32" )
46+
47+ # Test get_config
48+ bin_acc_obj_config = bin_acc_obj .get_config ()
49+ self .assertEqual (bin_acc_obj_config ["name" ], "binary_accuracy" )
50+ self .assertEqual (bin_acc_obj_config ["dtype" ], "float32" )
4151 # TODO: Check save and restore config
4252
4353 def test_unweighted (self ):
@@ -70,6 +80,11 @@ def test_config(self):
7080 self .assertEqual (cat_acc_obj .name , "categorical_accuracy" )
7181 self .assertEqual (len (cat_acc_obj .variables ), 2 )
7282 self .assertEqual (cat_acc_obj ._dtype , "float32" )
83+
84+ # Test get_config
85+ cat_acc_obj_config = cat_acc_obj .get_config ()
86+ self .assertEqual (cat_acc_obj_config ["name" ], "categorical_accuracy" )
87+ self .assertEqual (cat_acc_obj_config ["dtype" ], "float32" )
7388 # TODO: Check save and restore config
7489
7590 def test_unweighted (self ):
@@ -102,6 +117,13 @@ def test_config(self):
102117 self .assertEqual (sp_cat_acc_obj .name , "sparse_categorical_accuracy" )
103118 self .assertEqual (len (sp_cat_acc_obj .variables ), 2 )
104119 self .assertEqual (sp_cat_acc_obj ._dtype , "float32" )
120+
121+ # Test get_config
122+ sp_cat_acc_obj_config = sp_cat_acc_obj .get_config ()
123+ self .assertEqual (
124+ sp_cat_acc_obj_config ["name" ], "sparse_categorical_accuracy"
125+ )
126+ self .assertEqual (sp_cat_acc_obj_config ["dtype" ], "float32" )
105127 # TODO: Check save and restore config
106128
107129 def test_unweighted (self ):
@@ -124,3 +146,90 @@ def test_weighted(self):
124146 sp_cat_acc_obj .update_state (y_true , y_pred , sample_weight = sample_weight )
125147 result = sp_cat_acc_obj .result ()
126148 self .assertAllClose (result , 0.3 , atol = 1e-3 )
149+
150+
151+ class TopKCategoricalAccuracyTest (testing .TestCase ):
152+ def test_config (self ):
153+ top_k_cat_acc_obj = accuracy_metrics .TopKCategoricalAccuracy (
154+ k = 1 , name = "top_k_categorical_accuracy" , dtype = "float32"
155+ )
156+ self .assertEqual (top_k_cat_acc_obj .name , "top_k_categorical_accuracy" )
157+ self .assertEqual (len (top_k_cat_acc_obj .variables ), 2 )
158+ self .assertEqual (top_k_cat_acc_obj ._dtype , "float32" )
159+
160+ # Test get_config
161+ top_k_cat_acc_obj_config = top_k_cat_acc_obj .get_config ()
162+ self .assertEqual (
163+ top_k_cat_acc_obj_config ["name" ], "top_k_categorical_accuracy"
164+ )
165+ self .assertEqual (top_k_cat_acc_obj_config ["dtype" ], "float32" )
166+ self .assertEqual (top_k_cat_acc_obj_config ["k" ], 1 )
167+ # TODO: Check save and restore config
168+
169+ def test_unweighted (self ):
170+ top_k_cat_acc_obj = accuracy_metrics .TopKCategoricalAccuracy (
171+ k = 1 , name = "top_k_categorical_accuracy" , dtype = "float32"
172+ )
173+ y_true = np .array ([[0 , 0 , 1 ], [0 , 1 , 0 ]])
174+ y_pred = np .array ([[0.1 , 0.9 , 0.8 ], [0.05 , 0.95 , 0 ]], dtype = "float32" )
175+ top_k_cat_acc_obj .update_state (y_true , y_pred )
176+ result = top_k_cat_acc_obj .result ()
177+ self .assertAllClose (result , 0.5 , atol = 1e-3 )
178+
179+ def test_weighted (self ):
180+ top_k_cat_acc_obj = accuracy_metrics .TopKCategoricalAccuracy (
181+ k = 1 , name = "top_k_categorical_accuracy" , dtype = "float32"
182+ )
183+ y_true = np .array ([[0 , 0 , 1 ], [0 , 1 , 0 ]])
184+ y_pred = np .array ([[0.1 , 0.9 , 0.8 ], [0.05 , 0.95 , 0 ]], dtype = "float32" )
185+ sample_weight = np .array ([0.7 , 0.3 ])
186+ top_k_cat_acc_obj .update_state (
187+ y_true , y_pred , sample_weight = sample_weight
188+ )
189+ result = top_k_cat_acc_obj .result ()
190+ self .assertAllClose (result , 0.3 , atol = 1e-3 )
191+
192+
193+ class SparseTopKCategoricalAccuracyTest (testing .TestCase ):
194+ def test_config (self ):
195+ sp_top_k_cat_acc_obj = accuracy_metrics .SparseTopKCategoricalAccuracy (
196+ k = 1 , name = "sparse_top_k_categorical_accuracy" , dtype = "float32"
197+ )
198+ self .assertEqual (
199+ sp_top_k_cat_acc_obj .name , "sparse_top_k_categorical_accuracy"
200+ )
201+ self .assertEqual (len (sp_top_k_cat_acc_obj .variables ), 2 )
202+ self .assertEqual (sp_top_k_cat_acc_obj ._dtype , "float32" )
203+
204+ # Test get_config
205+ sp_top_k_cat_acc_obj_config = sp_top_k_cat_acc_obj .get_config ()
206+ self .assertEqual (
207+ sp_top_k_cat_acc_obj_config ["name" ],
208+ "sparse_top_k_categorical_accuracy" ,
209+ )
210+ self .assertEqual (sp_top_k_cat_acc_obj_config ["dtype" ], "float32" )
211+ self .assertEqual (sp_top_k_cat_acc_obj_config ["k" ], 1 )
212+ # TODO: Check save and restore config
213+
214+ def test_unweighted (self ):
215+ sp_top_k_cat_acc_obj = accuracy_metrics .SparseTopKCategoricalAccuracy (
216+ k = 1 , name = "sparse_top_k_categorical_accuracy" , dtype = "float32"
217+ )
218+ y_true = np .array ([2 , 1 ])
219+ y_pred = np .array ([[0.1 , 0.9 , 0.8 ], [0.05 , 0.95 , 0 ]], dtype = "float32" )
220+ sp_top_k_cat_acc_obj .update_state (y_true , y_pred )
221+ result = sp_top_k_cat_acc_obj .result ()
222+ self .assertAllClose (result , 0.5 , atol = 1e-3 )
223+
224+ def test_weighted (self ):
225+ sp_top_k_cat_acc_obj = accuracy_metrics .SparseTopKCategoricalAccuracy (
226+ k = 1 , name = "sparse_top_k_categorical_accuracy" , dtype = "float32"
227+ )
228+ y_true = np .array ([2 , 1 ])
229+ y_pred = np .array ([[0.1 , 0.9 , 0.8 ], [0.05 , 0.95 , 0 ]], dtype = "float32" )
230+ sample_weight = np .array ([0.7 , 0.3 ])
231+ sp_top_k_cat_acc_obj .update_state (
232+ y_true , y_pred , sample_weight = sample_weight
233+ )
234+ result = sp_top_k_cat_acc_obj .result ()
235+ self .assertAllClose (result , 0.3 , atol = 1e-3 )
0 commit comments