diff --git a/java/fory-core/src/main/java/org/apache/fory/Fory.java b/java/fory-core/src/main/java/org/apache/fory/Fory.java index a2840ef5d3..70f59dda51 100644 --- a/java/fory-core/src/main/java/org/apache/fory/Fory.java +++ b/java/fory-core/src/main/java/org/apache/fory/Fory.java @@ -37,6 +37,7 @@ import org.apache.fory.config.ForyBuilder; import org.apache.fory.config.Language; import org.apache.fory.config.LongEncoding; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.io.ForyInputStream; import org.apache.fory.io.ForyReadableChannel; import org.apache.fory.logging.Logger; @@ -798,7 +799,7 @@ public Object deserialize(byte[] bytes) { public T deserialize(byte[] bytes, Class type) { generics.pushGenericType(classResolver.buildGenericType(type)); try { - return (T) deserialize(MemoryUtils.wrap(bytes), null); + return (T) checkedDeserialize(MemoryUtils.wrap(bytes), null, type); } finally { generics.popGenericType(); } @@ -834,6 +835,12 @@ public Object deserialize(MemoryBuffer buffer) { */ @Override public Object deserialize(MemoryBuffer buffer, Iterable outOfBandBuffers) { + return checkedDeserialize(buffer, outOfBandBuffers, null); + } + + private Object checkedDeserialize(MemoryBuffer buffer, + Iterable outOfBandBuffers, + Class expectedType) { try { jitContext.lock(); if (depth != 0) { @@ -878,7 +885,13 @@ public Object deserialize(MemoryBuffer buffer, Iterable outOfBandB } Object obj; if (isTargetXLang) { - obj = xreadRef(buffer); + if (expectedType != null) { + obj = checkedXreadRef(buffer, expectedType); + } else { + obj = xreadRef(buffer); + } + } else if (expectedType != null) { + obj = checkedReadRef(buffer, expectedType); } else { obj = readRef(buffer); } @@ -934,6 +947,26 @@ public Object readRef(MemoryBuffer buffer) { } } + private Object checkedReadRef(MemoryBuffer buffer, Class expectedType) { + RefResolver refResolver = this.refResolver; + int nextReadRefId = refResolver.tryPreserveRefId(buffer); + if (nextReadRefId >= NOT_NULL_VALUE_FLAG) { + // ref value or not-null value + ClassInfo classInfo = classResolver.readClassInfo(buffer); + if (!expectedType.isAssignableFrom(classInfo.getCls())) { + throw new DeserializationException(String.format( + "Unexpected type %s which is not assignable to %s", + classInfo.getClass().getName(), + expectedType.getName())); + } + Object o = readDataInternal(buffer, classInfo); + refResolver.setReadObject(nextReadRefId, o); + return o; + } else { + return refResolver.getReadObject(); + } + } + public Object readRef(MemoryBuffer buffer, ClassInfoHolder classInfoHolder) { RefResolver refResolver = this.refResolver; int nextReadRefId = refResolver.tryPreserveRefId(buffer); @@ -1070,6 +1103,26 @@ public Object xreadRef(MemoryBuffer buffer) { } } + private Object checkedXreadRef(MemoryBuffer buffer, Class expectedType) { + RefResolver refResolver = this.refResolver; + int nextReadRefId = refResolver.tryPreserveRefId(buffer); + if (nextReadRefId >= NOT_NULL_VALUE_FLAG) { + ClassInfo classInfo = xtypeResolver.readClassInfo(buffer); + if (!expectedType.isAssignableFrom(classInfo.getCls())) { + throw new DeserializationException(String.format( + "Unexpected type %s which is not assignable to %s", + classInfo.getClass().getName(), + expectedType.getName())); + } + Object o = xreadNonRef(buffer, classInfo); + refResolver.setReadObject(nextReadRefId, o); + return o; + } else { + return refResolver.getReadObject(); + } + } + + public Object xreadRef(MemoryBuffer buffer, Serializer serializer) { if (serializer.needToWriteRef()) { RefResolver refResolver = this.refResolver; diff --git a/java/fory-core/src/test/java/org/apache/fory/ForyTest.java b/java/fory-core/src/test/java/org/apache/fory/ForyTest.java index 4d8a10a75b..abecf269a6 100644 --- a/java/fory-core/src/test/java/org/apache/fory/ForyTest.java +++ b/java/fory-core/src/test/java/org/apache/fory/ForyTest.java @@ -58,6 +58,7 @@ import org.apache.fory.config.CompatibleMode; import org.apache.fory.config.ForyBuilder; import org.apache.fory.config.Language; +import org.apache.fory.exception.DeserializationException; import org.apache.fory.exception.ForyException; import org.apache.fory.exception.InsecureException; import org.apache.fory.memory.MemoryBuffer; @@ -673,4 +674,20 @@ public void testStructMapping() { Assert.assertEquals(struct1.f1, struct2.f1); Assert.assertEquals(struct1.f2, struct2.f2); } + + @Test + public void testCheckedDeserialize() { + Fory fory = Fory.builder() + .withLanguage(Language.JAVA) + .withRefTracking(true) + .requireClassRegistration(false) + .build(); + LocalDate now = LocalDate.now(); + byte[] bytes = fory.serialize(now); + LocalDate ld0 = fory.deserialize(bytes, LocalDate.class); + Assert.assertEquals(ld0, now); + // Deserialize with wrong type and get a DeserializationException + assertThrows(DeserializationException.class, + () -> fory.deserialize(bytes, String.class)); + } }