From 6070c67785ad53c07a66c98550001e480c176499 Mon Sep 17 00:00:00 2001 From: Josiah Glosson Date: Sat, 13 May 2023 08:25:38 -0500 Subject: [PATCH] J17 -> J8 improvements (#43) * Desugar record equals, hashCode, and toString * No more runtime libraries for J17 -> J8 I dropped Apache Commons IO by implementing transferTo, then implementing readAllBytes as a transferTo a ByteArrayOutputStream * Fix lists being constructed in reverse Well they still are, but they're reversed at the end. * Better Stream.toList conversion --- .../viaproxy/injection/Java17ToJava8.java | 394 ++++++++++++++++-- 1 file changed, 364 insertions(+), 30 deletions(-) diff --git a/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java b/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java index ca0a756..2ad6f9e 100644 --- a/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java +++ b/src/main/java/net/raphimc/viaproxy/injection/Java17ToJava8.java @@ -20,20 +20,44 @@ package net.raphimc.viaproxy.injection; import net.lenni0451.classtransform.TransformerManager; import net.lenni0451.classtransform.transformer.IBytecodeTransformer; import net.lenni0451.classtransform.utils.ASMUtils; +import net.raphimc.viaproxy.util.logging.Logger; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; import org.objectweb.asm.Type; import org.objectweb.asm.tree.*; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.io.File; +import java.nio.file.Files; +import java.util.*; public class Java17ToJava8 implements IBytecodeTransformer { + private static final boolean DEBUG_DUMP = Boolean.getBoolean("viaproxy.debug.dump17to8"); + private static final char STACK_ARG_CONSTANT = '\u0001'; private static final char BSM_ARG_CONSTANT = '\u0002'; - final TransformerManager transformerManager; + private static final String TRANSFERTO_DESC = "(Ljava/io/InputStream;Ljava/io/OutputStream;)J"; + + private static final String EQUALS_DESC = "(Ljava/lang/Object;)Z"; + private static final String HASHCODE_DESC = "()I"; + private static final String TOSTRING_DESC = "()Ljava/lang/String;"; + private static final Map PRIMITIVE_WRAPPERS = new HashMap<>(); + + static { + PRIMITIVE_WRAPPERS.put("V", Type.getInternalName(Void.class)); + PRIMITIVE_WRAPPERS.put("Z", Type.getInternalName(Boolean.class)); + PRIMITIVE_WRAPPERS.put("B", Type.getInternalName(Byte.class)); + PRIMITIVE_WRAPPERS.put("S", Type.getInternalName(Short.class)); + PRIMITIVE_WRAPPERS.put("C", Type.getInternalName(Character.class)); + PRIMITIVE_WRAPPERS.put("I", Type.getInternalName(Integer.class)); + PRIMITIVE_WRAPPERS.put("F", Type.getInternalName(Float.class)); + PRIMITIVE_WRAPPERS.put("J", Type.getInternalName(Long.class)); + PRIMITIVE_WRAPPERS.put("D", Type.getInternalName(Double.class)); + } + + private final TransformerManager transformerManager; private final int nativeClassVersion; private final List whitelistedPackages = new ArrayList<>(); @@ -69,10 +93,20 @@ public class Java17ToJava8 implements IBytecodeTransformer { this.convertMapMethods(classNode); this.convertStreamMethods(classNode); this.convertMiscMethods(classNode); - this.removeRecords(classNode); + this.convertRecords(classNode); if (calculateStackMapFrames) { - return ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider()); + final byte[] result = ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider()); + if (DEBUG_DUMP) { + try { + final File file = new File("vp_17to8_dump", classNode.name + ".class"); + file.getParentFile().mkdirs(); + Files.write(file.toPath(), result); + } catch (Throwable e) { + Logger.LOGGER.error("Failed to dump class {}", className, e); + } + } + return result; } else { return ASMUtils.toStacklessBytes(classNode); } @@ -144,6 +178,13 @@ public class Java17ToJava8 implements IBytecodeTransformer { list.add(new InsnNode(Opcodes.POP)); } list.add(new VarInsnNode(Opcodes.ALOAD, freeVarIndex)); + list.add(new InsnNode(Opcodes.DUP)); + list.add(new MethodInsnNode( + Opcodes.INVOKESTATIC, + "java/util/Collections", + "reverse", + "(Ljava/util/List;)V" + )); list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Collections", "unmodifiableList", "(Ljava/util/List;)Ljava/util/List;")); } } else if (min.name.equals("copyOf")) { @@ -268,16 +309,25 @@ public class Java17ToJava8 implements IBytecodeTransformer { final InsnList list = new InsnList(); if (min.name.equals("toList")) { - int freeVarIndex = ASMUtils.getFreeVarIndex(method); - list.add(new VarInsnNode(Opcodes.ASTORE, freeVarIndex)); - - list.add(new TypeInsnNode(Opcodes.NEW, "java/util/ArrayList")); - list.add(new InsnNode(Opcodes.DUP)); - list.add(new VarInsnNode(Opcodes.ALOAD, freeVarIndex)); - list.add(new MethodInsnNode(Opcodes.INVOKEINTERFACE, "java/util/stream/Stream", "toArray", "()[Ljava/lang/Object;")); - list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Arrays", "asList", "([Ljava/lang/Object;)Ljava/util/List;")); - list.add(new MethodInsnNode(Opcodes.INVOKESPECIAL, "java/util/ArrayList", "", "(Ljava/util/Collection;)V")); - list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "java/util/Collections", "unmodifiableList", "(Ljava/util/List;)Ljava/util/List;")); + list.add(new MethodInsnNode( + Opcodes.INVOKESTATIC, + "java/util/stream/Collectors", + "toList", + "()Ljava/util/stream/Collector;" + )); + list.add(new MethodInsnNode( + Opcodes.INVOKEINTERFACE, + "java/util/stream/Stream", + "collect", + "(Ljava/util/stream/Collector;)Ljava/lang/Object;" + )); + list.add(new TypeInsnNode(Opcodes.CHECKCAST, "java/util/List")); + list.add(new MethodInsnNode( + Opcodes.INVOKESTATIC, + "java/util/Collections", + "unmodifiableList", + "(Ljava/util/List;)Ljava/util/List;") + ); } if (list.size() != 0) { @@ -290,6 +340,15 @@ public class Java17ToJava8 implements IBytecodeTransformer { } private void convertMiscMethods(final ClassNode node) { + boolean needsTransferTo = false; + String transferToName; + { + int i = 0; + do { + transferToName = "transferTo$" + i; + } while (ASMUtils.getMethod(node, transferToName, TRANSFERTO_DESC) != null); + } + for (MethodNode method : node.methods) { for (AbstractInsnNode insn : method.instructions.toArray()) { if (insn instanceof MethodInsnNode) { @@ -303,7 +362,36 @@ public class Java17ToJava8 implements IBytecodeTransformer { } } else if (min.owner.equals("java/io/InputStream")) { if (min.name.equals("readAllBytes") && min.getOpcode() == Opcodes.INVOKEVIRTUAL) { - list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, "org/apache/commons/io/IOUtils", "toByteArray", "(Ljava/io/InputStream;)[B")); + needsTransferTo = true; + list.add(new TypeInsnNode(Opcodes.NEW, "java/io/ByteArrayOutputStream")); + list.add(new InsnNode(Opcodes.DUP)); + list.add(new MethodInsnNode( + Opcodes.INVOKESPECIAL, + "java/io/ByteArrayOutputStream", + "", + "()V" + )); + list.add(new InsnNode(Opcodes.DUP_X1)); + list.add(new MethodInsnNode( + Opcodes.INVOKESTATIC, + node.name, + transferToName, + TRANSFERTO_DESC + )); + list.add(new InsnNode(Opcodes.POP2)); + list.add(new MethodInsnNode( + Opcodes.INVOKEVIRTUAL, + "java/io/ByteArrayOutputStream", + "toByteArray", + "()[B" + )); + } else if (min.name.equals("transferTo") && min.getOpcode() == Opcodes.INVOKEVIRTUAL) { + needsTransferTo = true; + list.add(new MethodInsnNode(Opcodes.INVOKESTATIC, + node.name, + transferToName, + TRANSFERTO_DESC + )); } } else if (min.owner.equals("java/nio/file/FileSystems")) { if (min.name.equals("newFileSystem") && min.desc.equals("(Ljava/nio/file/Path;Ljava/util/Map;Ljava/lang/ClassLoader;)Ljava/nio/file/FileSystem;")) { @@ -402,26 +490,272 @@ public class Java17ToJava8 implements IBytecodeTransformer { } } } + + if (needsTransferTo) { + // I compiled this by hand btw + final MethodVisitor transferTo = node.visitMethod( + Opcodes.ACC_PRIVATE | Opcodes.ACC_STATIC | Opcodes.ACC_SYNTHETIC, + transferToName, TRANSFERTO_DESC, null, new String[] {"java/io/IOException"} + ); + transferTo.visitCode(); + + // Objects.requireNonNull(out, "out"); + transferTo.visitVarInsn(Opcodes.ALOAD, 1); + transferTo.visitLdcInsn("out"); + transferTo.visitMethodInsn( + Opcodes.INVOKESTATIC, + "java/util/Objects", + "requireNonNull", + "(Ljava/lang/Object;Ljava/lang/String;)Ljava/lang/Object;", + false + ); + transferTo.visitInsn(Opcodes.POP); + + // long transferred = 0; + transferTo.visitInsn(Opcodes.LCONST_0); + transferTo.visitVarInsn(Opcodes.LSTORE, 2); + + // byte[] buffer = new byte[DEFAULT_BUFFER_SIZE]; + transferTo.visitIntInsn(Opcodes.SIPUSH, 8192); + transferTo.visitIntInsn(Opcodes.NEWARRAY, Opcodes.T_BYTE); + transferTo.visitVarInsn(Opcodes.ASTORE, 4); + + // while ((read = this.read(buffer, 0, DEFAULT_BUFFER_SIZE)) >= 0) { + final Label whileStart = new Label(); + final Label whileEnd = new Label(); + transferTo.visitLabel(whileStart); + transferTo.visitVarInsn(Opcodes.ALOAD, 0); + transferTo.visitVarInsn(Opcodes.ALOAD, 4); + transferTo.visitInsn(Opcodes.ICONST_0); + transferTo.visitIntInsn(Opcodes.SIPUSH, 8192); + transferTo.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/io/InputStream", + "read", + "([BII)I", + false + ); + transferTo.visitInsn(Opcodes.DUP); + transferTo.visitVarInsn(Opcodes.ISTORE, 5); + transferTo.visitJumpInsn(Opcodes.IFLT, whileEnd); + + // out.write(buffer, 0, read); + transferTo.visitVarInsn(Opcodes.ALOAD, 1); + transferTo.visitVarInsn(Opcodes.ALOAD, 4); + transferTo.visitInsn(Opcodes.ICONST_0); + transferTo.visitVarInsn(Opcodes.ILOAD, 5); + transferTo.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/io/OutputStream", + "write", + "([BII)V", + false + ); + + // transferred += read; + transferTo.visitVarInsn(Opcodes.LLOAD, 2); + transferTo.visitVarInsn(Opcodes.ILOAD, 5); + transferTo.visitInsn(Opcodes.I2L); + transferTo.visitInsn(Opcodes.LADD); + transferTo.visitVarInsn(Opcodes.LSTORE, 2); + + // } + transferTo.visitJumpInsn(Opcodes.GOTO, whileStart); + transferTo.visitLabel(whileEnd); + + // return transferred; + transferTo.visitVarInsn(Opcodes.LLOAD, 2); + transferTo.visitInsn(Opcodes.LRETURN); + + transferTo.visitEnd(); + } } - private void removeRecords(final ClassNode node) { - if (node.superName.equals("java/lang/Record")) { - node.access &= ~Opcodes.ACC_RECORD; - node.superName = "java/lang/Object"; + private void convertRecords(final ClassNode node) { + if (!node.superName.equals("java/lang/Record")) return; - List constructors = ASMUtils.getMethodsFromCombi(node, ""); - for (MethodNode method : constructors) { - for (AbstractInsnNode insn : method.instructions.toArray()) { - if (insn.getOpcode() == Opcodes.INVOKESPECIAL) { - MethodInsnNode min = (MethodInsnNode) insn; - if (min.owner.equals("java/lang/Record")) { - min.owner = "java/lang/Object"; - break; - } + node.access &= ~Opcodes.ACC_RECORD; + node.superName = "java/lang/Object"; + + final List constructors = ASMUtils.getMethodsFromCombi(node, ""); + for (MethodNode method : constructors) { + for (AbstractInsnNode insn : method.instructions.toArray()) { + if (insn.getOpcode() == Opcodes.INVOKESPECIAL) { + MethodInsnNode min = (MethodInsnNode) insn; + if (min.owner.equals("java/lang/Record")) { + min.owner = "java/lang/Object"; + break; } } } } + + node.methods.remove(ASMUtils.getMethod(node, "equals", EQUALS_DESC)); + final MethodVisitor equals = node.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null); + { + equals.visitCode(); + + equals.visitVarInsn(Opcodes.ALOAD, 0); + equals.visitVarInsn(Opcodes.ALOAD, 1); + final Label notSameLabel = new Label(); + equals.visitJumpInsn(Opcodes.IF_ACMPNE, notSameLabel); + equals.visitInsn(Opcodes.ICONST_1); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notSameLabel); + + // Original uses Class.isInstance, but I think instanceof is more fitting here + equals.visitVarInsn(Opcodes.ALOAD, 1); + equals.visitTypeInsn(Opcodes.INSTANCEOF, node.name); + final Label notIsInstanceLabel = new Label(); + equals.visitJumpInsn(Opcodes.IFNE, notIsInstanceLabel); + equals.visitInsn(Opcodes.ICONST_0); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notIsInstanceLabel); + + equals.visitVarInsn(Opcodes.ALOAD, 1); + equals.visitTypeInsn(Opcodes.CHECKCAST, node.name); + equals.visitVarInsn(Opcodes.ASTORE, 2); + + final Label notEqualLabel = new Label(); + for (final RecordComponentNode component : node.recordComponents) { + equals.visitVarInsn(Opcodes.ALOAD, 0); + equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + equals.visitVarInsn(Opcodes.ALOAD, 2); + equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + if (Type.getType(component.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Objects.class), + "equals", + "(Ljava/lang/Object;Ljava/lang/Object;)Z", + false + ); + equals.visitJumpInsn(Opcodes.IFEQ, notEqualLabel); + continue; + } else if ("BSCIZ".contains(component.descriptor)) { + equals.visitJumpInsn(Opcodes.IF_ICMPNE, notEqualLabel); + continue; + } else if (component.descriptor.equals("F")) { + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Float.class), + "equals", + "(FF)Z", + false + ); + } else if (component.descriptor.equals("D")) { + equals.visitMethodInsn( + Opcodes.INVOKESTATIC, + Type.getInternalName(Double.class), + "equals", + "(DD)Z", + false + ); + } else if (component.descriptor.equals("J")) { + equals.visitInsn(Opcodes.LCMP); + } else { + throw new AssertionError("Unknown descriptor " + component.descriptor); + } + equals.visitJumpInsn(Opcodes.IFNE, notEqualLabel); + } + equals.visitInsn(Opcodes.ICONST_1); + equals.visitInsn(Opcodes.IRETURN); + equals.visitLabel(notEqualLabel); + equals.visitInsn(Opcodes.ICONST_0); + equals.visitInsn(Opcodes.IRETURN); + + equals.visitEnd(); + } + + node.methods.remove(ASMUtils.getMethod(node, "hashCode", HASHCODE_DESC)); + final MethodVisitor hashCode = node.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null); + { + hashCode.visitCode(); + + hashCode.visitInsn(Opcodes.ICONST_0); + for (final RecordComponentNode component : node.recordComponents) { + hashCode.visitIntInsn(Opcodes.BIPUSH, 31); + hashCode.visitInsn(Opcodes.IMUL); + hashCode.visitVarInsn(Opcodes.ALOAD, 0); + hashCode.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor); + hashCode.visitMethodInsn( + Opcodes.INVOKESTATIC, + owner != null ? owner : "java/util/Objects", + "hashCode", + "(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")I", + false + ); + hashCode.visitInsn(Opcodes.IADD); + } + hashCode.visitInsn(Opcodes.IRETURN); + + hashCode.visitEnd(); + } + + node.methods.remove(ASMUtils.getMethod(node, "toString", TOSTRING_DESC)); + final MethodVisitor toString = node.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null); + { + toString.visitCode(); + + final StringBuilder formatString = new StringBuilder("%s["); + for (int i = 0; i < node.recordComponents.size(); i++) { + formatString.append(node.recordComponents.get(i).name).append("=%s"); + if (i != node.recordComponents.size() - 1) { + formatString.append(", "); + } + } + formatString.append(']'); + + toString.visitLdcInsn(formatString.toString()); + toString.visitIntInsn(Opcodes.SIPUSH, node.recordComponents.size() + 1); + toString.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object"); + toString.visitInsn(Opcodes.DUP); + toString.visitInsn(Opcodes.ICONST_0); + toString.visitVarInsn(Opcodes.ALOAD, 0); + toString.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/Object", + "getClass", + "()Ljava/lang/Class;", + false + ); + toString.visitMethodInsn( + Opcodes.INVOKEVIRTUAL, + "java/lang/Class", + "getSimpleName", + "()Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.AASTORE); + int i = 1; + for (final RecordComponentNode component : node.recordComponents) { + toString.visitInsn(Opcodes.DUP); + toString.visitIntInsn(Opcodes.SIPUSH, i); + toString.visitVarInsn(Opcodes.ALOAD, 0); + toString.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor); + final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor); + toString.visitMethodInsn( + Opcodes.INVOKESTATIC, + owner != null ? owner : "java/util/Objects", + "toString", + "(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.AASTORE); + i++; + } + toString.visitMethodInsn( + Opcodes.INVOKESTATIC, + "java/lang/String", + "format", + "(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;", + false + ); + toString.visitInsn(Opcodes.ARETURN); + + toString.visitEnd(); + } } private int count(final String s, final char search) {