Desugar record equals, hashCode, and toString

This commit is contained in:
Josiah (Gaming32) Glosson 2023-05-12 20:59:08 -05:00
parent 565690fcb1
commit d95f3a277a

View File

@ -20,20 +20,42 @@ package net.raphimc.viaproxy.injection;
import net.lenni0451.classtransform.TransformerManager;
import net.lenni0451.classtransform.transformer.IBytecodeTransformer;
import net.lenni0451.classtransform.utils.ASMUtils;
import net.raphimc.viaproxy.util.logging.Logger;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.io.File;
import java.nio.file.Files;
import java.util.*;
public class Java17ToJava8 implements IBytecodeTransformer {
private static final boolean DEBUG_DUMP = Boolean.getBoolean("viaproxy.debug.dump17to8");
private static final char STACK_ARG_CONSTANT = '\u0001';
private static final char BSM_ARG_CONSTANT = '\u0002';
final TransformerManager transformerManager;
private static final String EQUALS_DESC = "(Ljava/lang/Object;)Z";
private static final String HASHCODE_DESC = "()I";
private static final String TOSTRING_DESC = "()Ljava/lang/String;";
private static final Map<String, String> PRIMITIVE_WRAPPERS = new HashMap<>();
static {
PRIMITIVE_WRAPPERS.put("V", Type.getInternalName(Void.class));
PRIMITIVE_WRAPPERS.put("Z", Type.getInternalName(Boolean.class));
PRIMITIVE_WRAPPERS.put("B", Type.getInternalName(Byte.class));
PRIMITIVE_WRAPPERS.put("S", Type.getInternalName(Short.class));
PRIMITIVE_WRAPPERS.put("C", Type.getInternalName(Character.class));
PRIMITIVE_WRAPPERS.put("I", Type.getInternalName(Integer.class));
PRIMITIVE_WRAPPERS.put("F", Type.getInternalName(Float.class));
PRIMITIVE_WRAPPERS.put("J", Type.getInternalName(Long.class));
PRIMITIVE_WRAPPERS.put("D", Type.getInternalName(Double.class));
}
private final TransformerManager transformerManager;
private final int nativeClassVersion;
private final List<String> whitelistedPackages = new ArrayList<>();
@ -69,10 +91,20 @@ public class Java17ToJava8 implements IBytecodeTransformer {
this.convertMapMethods(classNode);
this.convertStreamMethods(classNode);
this.convertMiscMethods(classNode);
this.removeRecords(classNode);
this.convertRecords(classNode);
if (calculateStackMapFrames) {
return ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider());
final byte[] result = ASMUtils.toBytes(classNode, this.transformerManager.getClassTree(), this.transformerManager.getClassProvider());
if (DEBUG_DUMP) {
try {
final File file = new File("vp_17to8_dump", classNode.name + ".class");
file.getParentFile().mkdirs();
Files.write(file.toPath(), result);
} catch (Throwable e) {
Logger.LOGGER.error("Failed to dump class {}", className, e);
}
}
return result;
} else {
return ASMUtils.toStacklessBytes(classNode);
}
@ -404,24 +436,191 @@ public class Java17ToJava8 implements IBytecodeTransformer {
}
}
private void removeRecords(final ClassNode node) {
if (node.superName.equals("java/lang/Record")) {
node.access &= ~Opcodes.ACC_RECORD;
node.superName = "java/lang/Object";
private void convertRecords(final ClassNode node) {
if (!node.superName.equals("java/lang/Record")) return;
List<MethodNode> constructors = ASMUtils.getMethodsFromCombi(node, "<init>");
for (MethodNode method : constructors) {
for (AbstractInsnNode insn : method.instructions.toArray()) {
if (insn.getOpcode() == Opcodes.INVOKESPECIAL) {
MethodInsnNode min = (MethodInsnNode) insn;
if (min.owner.equals("java/lang/Record")) {
min.owner = "java/lang/Object";
break;
}
node.access &= ~Opcodes.ACC_RECORD;
node.superName = "java/lang/Object";
final List<MethodNode> constructors = ASMUtils.getMethodsFromCombi(node, "<init>");
for (MethodNode method : constructors) {
for (AbstractInsnNode insn : method.instructions.toArray()) {
if (insn.getOpcode() == Opcodes.INVOKESPECIAL) {
MethodInsnNode min = (MethodInsnNode) insn;
if (min.owner.equals("java/lang/Record")) {
min.owner = "java/lang/Object";
break;
}
}
}
}
node.methods.remove(ASMUtils.getMethod(node, "equals", EQUALS_DESC));
final MethodVisitor equals = node.visitMethod(Opcodes.ACC_PUBLIC, "equals", EQUALS_DESC, null, null);
{
equals.visitCode();
equals.visitVarInsn(Opcodes.ALOAD, 0);
equals.visitVarInsn(Opcodes.ALOAD, 1);
final Label notSameLabel = new Label();
equals.visitJumpInsn(Opcodes.IF_ACMPNE, notSameLabel);
equals.visitInsn(Opcodes.ICONST_1);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notSameLabel);
// Original uses Class.isInstance, but I think instanceof is more fitting here
equals.visitVarInsn(Opcodes.ALOAD, 1);
equals.visitTypeInsn(Opcodes.INSTANCEOF, node.name);
final Label notIsInstanceLabel = new Label();
equals.visitJumpInsn(Opcodes.IFNE, notIsInstanceLabel);
equals.visitInsn(Opcodes.ICONST_0);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notIsInstanceLabel);
equals.visitVarInsn(Opcodes.ALOAD, 1);
equals.visitTypeInsn(Opcodes.CHECKCAST, node.name);
equals.visitVarInsn(Opcodes.ASTORE, 2);
final Label notEqualLabel = new Label();
for (final RecordComponentNode component : node.recordComponents) {
equals.visitVarInsn(Opcodes.ALOAD, 0);
equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
equals.visitVarInsn(Opcodes.ALOAD, 2);
equals.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
if (Type.getType(component.descriptor).getSort() >= Type.ARRAY) { // ARRAY or OBJECT
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Objects.class),
"equals",
"(Ljava/lang/Object;Ljava/lang/Object;)Z",
false
);
equals.visitJumpInsn(Opcodes.IFEQ, notEqualLabel);
continue;
} else if ("BSCIZ".contains(component.descriptor)) {
equals.visitJumpInsn(Opcodes.IF_ICMPNE, notEqualLabel);
continue;
} else if (component.descriptor.equals("F")) {
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Float.class),
"equals",
"(FF)Z",
false
);
} else if (component.descriptor.equals("D")) {
equals.visitMethodInsn(
Opcodes.INVOKESTATIC,
Type.getInternalName(Double.class),
"equals",
"(DD)Z",
false
);
} else if (component.descriptor.equals("J")) {
equals.visitInsn(Opcodes.LCMP);
} else {
throw new AssertionError("Unknown descriptor " + component.descriptor);
}
equals.visitJumpInsn(Opcodes.IFNE, notEqualLabel);
}
equals.visitInsn(Opcodes.ICONST_1);
equals.visitInsn(Opcodes.IRETURN);
equals.visitLabel(notEqualLabel);
equals.visitInsn(Opcodes.ICONST_0);
equals.visitInsn(Opcodes.IRETURN);
equals.visitEnd();
}
node.methods.remove(ASMUtils.getMethod(node, "hashCode", HASHCODE_DESC));
final MethodVisitor hashCode = node.visitMethod(Opcodes.ACC_PUBLIC, "hashCode", HASHCODE_DESC, null, null);
{
hashCode.visitCode();
hashCode.visitInsn(Opcodes.ICONST_0);
for (final RecordComponentNode component : node.recordComponents) {
hashCode.visitIntInsn(Opcodes.BIPUSH, 31);
hashCode.visitInsn(Opcodes.IMUL);
hashCode.visitVarInsn(Opcodes.ALOAD, 0);
hashCode.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
hashCode.visitMethodInsn(
Opcodes.INVOKESTATIC,
owner != null ? owner : "java/util/Objects",
"hashCode",
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")I",
false
);
hashCode.visitInsn(Opcodes.IADD);
}
hashCode.visitInsn(Opcodes.IRETURN);
hashCode.visitEnd();
}
node.methods.remove(ASMUtils.getMethod(node, "toString", TOSTRING_DESC));
final MethodVisitor toString = node.visitMethod(Opcodes.ACC_PUBLIC, "toString", TOSTRING_DESC, null, null);
{
toString.visitCode();
final StringBuilder formatString = new StringBuilder("%s[");
for (int i = 0; i < node.recordComponents.size(); i++) {
formatString.append(node.recordComponents.get(i).name).append("=%s");
if (i != node.recordComponents.size() - 1) {
formatString.append(", ");
}
}
formatString.append(']');
toString.visitLdcInsn(formatString.toString());
toString.visitIntInsn(Opcodes.SIPUSH, node.recordComponents.size() + 1);
toString.visitTypeInsn(Opcodes.ANEWARRAY, "java/lang/Object");
toString.visitInsn(Opcodes.DUP);
toString.visitInsn(Opcodes.ICONST_0);
toString.visitVarInsn(Opcodes.ALOAD, 0);
toString.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/lang/Object",
"getClass",
"()Ljava/lang/Class;",
false
);
toString.visitMethodInsn(
Opcodes.INVOKEVIRTUAL,
"java/lang/Class",
"getSimpleName",
"()Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.AASTORE);
int i = 1;
for (final RecordComponentNode component : node.recordComponents) {
toString.visitInsn(Opcodes.DUP);
toString.visitIntInsn(Opcodes.SIPUSH, i);
toString.visitVarInsn(Opcodes.ALOAD, 0);
toString.visitFieldInsn(Opcodes.GETFIELD, node.name, component.name, component.descriptor);
final String owner = PRIMITIVE_WRAPPERS.get(component.descriptor);
toString.visitMethodInsn(
Opcodes.INVOKESTATIC,
owner != null ? owner : "java/util/Objects",
"toString",
"(" + (owner != null ? component.descriptor : "Ljava/lang/Object;") + ")Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.AASTORE);
i++;
}
toString.visitMethodInsn(
Opcodes.INVOKESTATIC,
"java/lang/String",
"format",
"(Ljava/lang/String;[Ljava/lang/Object;)Ljava/lang/String;",
false
);
toString.visitInsn(Opcodes.ARETURN);
toString.visitEnd();
}
}
private int count(final String s, final char search) {