diff --git a/src/main/java/net/elytrium/velocitytools/hooks/HooksInitializer.java b/src/main/java/net/elytrium/velocitytools/hooks/HooksInitializer.java index 2f58cd6..e6ce7ce 100644 --- a/src/main/java/net/elytrium/velocitytools/hooks/HooksInitializer.java +++ b/src/main/java/net/elytrium/velocitytools/hooks/HooksInitializer.java @@ -33,7 +33,7 @@ import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.reflect.Field; -import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.BiConsumer; @@ -68,36 +68,9 @@ public static void init(ProxyServer server) { MethodHandle versionsField = MethodHandles.privateLookupIn(StateRegistry.PacketRegistry.class, MethodHandles.lookup()) .findGetter(StateRegistry.PacketRegistry.class, "versions", Map.class); - MethodHandle packetIdToSupplierField = MethodHandles - .privateLookupIn(StateRegistry.PacketRegistry.ProtocolRegistry.class, MethodHandles.lookup()) - .findGetter(StateRegistry.PacketRegistry.ProtocolRegistry.class, "packetIdToSupplier", IntObjectMap.class); - - MethodHandle packetClassToIdField = MethodHandles - .privateLookupIn(StateRegistry.PacketRegistry.ProtocolRegistry.class, MethodHandles.lookup()) - .findGetter(StateRegistry.PacketRegistry.ProtocolRegistry.class, "packetClassToId", Object2IntMap.class); - - List hooks = new ArrayList<>(); - hooks.add(new PluginMessageHook()); - hooks.add(new HandshakeHook()); - - BiConsumer consumer = (version, registry) -> { - try { - IntObjectMap> packetIdToSupplier - = (IntObjectMap>) packetIdToSupplierField.invoke(registry); - - Object2IntMap> packetClassToId - = (Object2IntMap>) packetClassToIdField.invoke(registry); - - hooks.forEach(hook -> { - int packetId = packetClassToId.getInt(hook.getType()); - packetClassToId.put(hook.getHookClass(), packetId); - packetIdToSupplier.remove(packetId); - packetIdToSupplier.put(packetId, hook.getHook()); - }); - } catch (Throwable e) { - throw new ReflectionException(e); - } - }; + Map> hooks = new LinkedHashMap<>(); + hooks.put(StateRegistry.PLAY, List.of(new PluginMessageHook())); + hooks.put(StateRegistry.HANDSHAKE, List.of(new HandshakeHook())); MethodHandle clientboundGetter = MethodHandles.privateLookupIn(StateRegistry.class, MethodHandles.lookup()) .findGetter(StateRegistry.class, "clientbound", StateRegistry.PacketRegistry.class); @@ -109,11 +82,44 @@ public static void init(ProxyServer server) { StateRegistry.PacketRegistry configClientbound = (StateRegistry.PacketRegistry) clientboundGetter.invokeExact(StateRegistry.CONFIG); StateRegistry.PacketRegistry handshakeServerbound = (StateRegistry.PacketRegistry) serverboundGetter.invokeExact(StateRegistry.HANDSHAKE); - ((Map) versionsField.invokeExact(playClientbound)).forEach(consumer); - ((Map) versionsField.invokeExact(configClientbound)).forEach(consumer); - ((Map) versionsField.invokeExact(handshakeServerbound)).forEach(consumer); + ((Map) versionsField.invokeExact(playClientbound)) + .forEach(hookInitializer(hooks.getOrDefault(StateRegistry.PLAY, List.of()))); + ((Map) versionsField.invokeExact(configClientbound)) + .forEach(hookInitializer(hooks.getOrDefault(StateRegistry.CONFIG, List.of()))); + ((Map) versionsField.invokeExact(handshakeServerbound)) + .forEach(hookInitializer(hooks.getOrDefault(StateRegistry.HANDSHAKE, List.of()))); } catch (Throwable e) { throw new ReflectionException(e); } } + + private static BiConsumer hookInitializer(List hooks) + throws IllegalAccessException, NoSuchFieldException { + MethodHandle packetIdToSupplierField = MethodHandles + .privateLookupIn(StateRegistry.PacketRegistry.ProtocolRegistry.class, MethodHandles.lookup()) + .findGetter(StateRegistry.PacketRegistry.ProtocolRegistry.class, "packetIdToSupplier", IntObjectMap.class); + + MethodHandle packetClassToIdField = MethodHandles + .privateLookupIn(StateRegistry.PacketRegistry.ProtocolRegistry.class, MethodHandles.lookup()) + .findGetter(StateRegistry.PacketRegistry.ProtocolRegistry.class, "packetClassToId", Object2IntMap.class); + + return (version, registry) -> { + try { + IntObjectMap> packetIdToSupplier + = (IntObjectMap>) packetIdToSupplierField.invoke(registry); + + Object2IntMap> packetClassToId + = (Object2IntMap>) packetClassToIdField.invoke(registry); + + hooks.forEach(hook -> { + int packetId = packetClassToId.getInt(hook.getType()); + packetClassToId.put(hook.getHookClass(), packetId); + packetIdToSupplier.remove(packetId); + packetIdToSupplier.put(packetId, hook.getHook()); + }); + } catch (Throwable e) { + throw new ReflectionException(e); + } + }; + } }