diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 64609ab9dd9a62..760bb5096bef2d 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -173,6 +173,8 @@ def test_custom_serializer_and_deserializer(self): def serializer(obj, protocol): if isinstance(obj, (bytes, bytearray, str)): if protocol == 5: + if isinstance(obj, bytearray): + return bytes(obj) return obj return type(obj).__name__ elif isinstance(obj, array.array): @@ -222,22 +224,25 @@ def deserializer(data): s["array_data"], array_data.tobytes().decode() ) - def test_custom_incomplete_serializer_and_deserializer(self): - dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") + def test_custom_incomplete_serializer(self): os.mkdir(self.dirname) self.addCleanup(os_helper.rmtree, self.dirname) - with self.assertRaises(dbm_sqlite3.error): - def serializer(obj, protocol=None): - pass + def serializer(obj, protocol=None): + pass - def deserializer(data): - return data.decode("utf-8") + def deserializer(data): + return data.decode("utf-8") + with self.assertRaises((TypeError, dbm.error)): with shelve.open(self.fn, serializer=serializer, deserializer=deserializer) as s: s["foo"] = "bar" + def test_custom_incomplete_deserializer(self): + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + def serializer(obj, protocol=None): return type(obj).__name__.encode("utf-8") @@ -352,7 +357,7 @@ def type_name_len(obj): self.assertEqual(s["bytearray_data"], "bytearray") self.assertEqual(s["array_data"], "array") - def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self): + def test_custom_incomplete_deserializer_bsd_db_shelf(self): berkeleydb = import_helper.import_module("berkeleydb") os.mkdir(self.dirname) self.addCleanup(os_helper.rmtree, self.dirname) @@ -370,6 +375,11 @@ def deserializer(data): self.assertIsNone(s["foo"]) self.assertNotEqual(s["foo"], "bar") + def test_custom_incomplete_serializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + def serializer(obj, protocol=None): pass @@ -399,6 +409,28 @@ def deserializer(data): self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs) self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs) + def test_custom_serializer_returns_wrong_type_for_key(self): + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol): + # Return None instead of bytes, which is wrong for dbm keys + return None + + def deserializer(data): + return data.decode("utf-8") if data else "" + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto), shelve.open( + self.fn, + protocol=proto, + serializer=serializer, + deserializer=deserializer + ) as s: + # Serializer returns None for the value, but dbm expects bytes + with self.assertRaises((TypeError, dbm.error)): + s["foo"] = "bar" + class TestShelveBase: type2test = shelve.Shelf