package io.shiftleft.bctrace.asm;

import io.shiftleft.bctrace.Bctrace;
import io.shiftleft.bctrace.asm.util.ASMUtils;
import io.shiftleft.bctrace.hook.Hook;
import io.shiftleft.bctrace.logging.Level;
import io.shiftleft.bctrace.runtime.listener.direct.DirectListener;
import io.shiftleft.bctrace.util.Utils;
import java.io.IOException;
import java.io.InputStream;
import java.lang.instrument.ClassFileTransformer;
import java.lang.instrument.IllegalClassFormatException;
import java.lang.reflect.Method;
import java.security.ProtectionDomain;
import java.util.HashSet;
import java.util.Set;
import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Type;
import org.objectweb.asm.tree.ClassNode;
import org.objectweb.asm.tree.FieldInsnNode;
import org.objectweb.asm.tree.FrameNode;
import org.objectweb.asm.tree.InsnList;
import org.objectweb.asm.tree.InsnNode;
import org.objectweb.asm.tree.JumpInsnNode;
import org.objectweb.asm.tree.LabelNode;
import org.objectweb.asm.tree.LocalVariableNode;
import org.objectweb.asm.tree.MethodInsnNode;
import org.objectweb.asm.tree.MethodNode;
import org.objectweb.asm.tree.TryCatchBlockNode;
import org.objectweb.asm.tree.TypeInsnNode;
import org.objectweb.asm.tree.VarInsnNode;

/* loaded from: input_file:io/shiftleft/bctrace/asm/CallbackTransformer.class */
/**
 * {@link ClassFileTransformer} that rewrites {@code io/shiftleft/bctrace/runtime/Callback} by
 * adding one generated {@code public static} synthetic method per distinct {@link DirectListener}
 * found among the configured hooks.
 *
 * <p>Each generated method takes the listener index as its first ({@code int}) argument and
 * dispatches to {@code Callback.listeners[index]}, cast to the listener interface. The dispatch
 * only happens when {@code CallbackEnabler.isThreadNotificationEnabled()} returns true and the
 * {@code Callback.NOTIFYING_FLAG} thread-local is not {@code Boolean.TRUE}. The flag is set to
 * TRUE around the listener call, so nested notifications on the same thread are skipped.
 */
public class CallbackTransformer implements ClassFileTransformer {
    private static final String CALLBACK_JVM_CLASS_NAME = "io/shiftleft/bctrace/runtime/Callback";
    private static final String CALLBACK_ENABLER_JVM_CLASS_NAME = "io/shiftleft/bctrace/runtime/CallbackEnabler";

    // Bytecode constants used by the generated methods (same values as org.objectweb.asm.Opcodes).
    private static final int ACC_PUBLIC = 0x0001;
    private static final int ACC_STATIC = 0x0008;
    private static final int ACC_SYNTHETIC = 0x1000;
    private static final int F_SAME = 3;
    private static final int F_SAME1 = 4;
    private static final int ILOAD = 21;
    private static final int AALOAD = 50;
    private static final int IFNE = 154;
    private static final int IF_ACMPNE = 166;
    private static final int RETURN = 177;
    private static final int GETSTATIC = 178;
    private static final int INVOKEVIRTUAL = 182;
    private static final int INVOKESTATIC = 184;
    private static final int INVOKEINTERFACE = 185;
    private static final int CHECKCAST = 192;

    private final Hook[] hooks;
    private volatile boolean completed = false;

    public CallbackTransformer(Hook[] hookArr) {
        this.hooks = hookArr;
    }

    /**
     * Adds the generated listener methods to the Callback class; every other class is left
     * untouched (returns {@code null}).
     *
     * <p>On success, or when no {@link DirectListener} is configured, {@link #isCompleted()}
     * becomes true. Any failure is logged and {@code null} is returned, so the class loads
     * unmodified.
     */
    @Override
    public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
            ProtectionDomain protectionDomain, byte[] classfileBuffer) throws IllegalClassFormatException {
        if (!CALLBACK_JVM_CLASS_NAME.equals(className)) {
            return null;
        }
        try {
            Set<DirectListener> directListeners = collectDirectListeners();
            if (directListeners.isEmpty()) {
                // Nothing to generate: Callback is loaded unchanged.
                this.completed = true;
                return null;
            }
            ClassReader classReader = readCallbackClass();
            ClassNode classNode = new ClassNode();
            // NOTE(review): ClassReader.accept() overwrites ClassNode.version with the version
            // declared by the bundled Callback.class, so a version cannot be preset before this
            // call. If a specific version (e.g. 50) is required, assign it after accept().
            classReader.accept(classNode, 0);
            addListenerMethods(classNode, directListeners);
            ClassWriter classWriter = new ClassWriter(classReader, ClassWriter.COMPUTE_MAXS);
            classNode.accept(classWriter);
            this.completed = true;
            return classWriter.toByteArray();
        } catch (Throwable th) {
            Bctrace.getAgentLogger().setLevel(Level.ERROR);
            Bctrace.getAgentLogger().log(Level.ERROR, "Error found instrumenting class " + className, th);
            th.printStackTrace();
            return null;
        }
    }

    /** Returns the distinct {@link DirectListener}s registered in {@link #hooks}. */
    private Set<DirectListener> collectDirectListeners() {
        Set<DirectListener> directListeners = new HashSet<DirectListener>();
        for (Hook hook : this.hooks) {
            Object listener = hook.getListener();
            if (listener instanceof DirectListener) {
                directListeners.add((DirectListener) listener);
            }
        }
        return directListeners;
    }

    /**
     * Reads the original Callback class file through the system class loader. The stream is
     * closed explicitly because ASM's {@code ClassReader(InputStream)} does not close it.
     */
    private static ClassReader readCallbackClass() throws IOException {
        InputStream in = ClassLoader.getSystemResourceAsStream(CALLBACK_JVM_CLASS_NAME + ".class");
        try {
            return new ClassReader(in);
        } finally {
            if (in != null) {
                try {
                    in.close();
                } catch (IOException ignored) {
                    // The class bytes were already read (or reading failed); a close failure must
                    // not abort the transformation.
                }
            }
        }
    }

    /**
     * Appends one generated method per distinct generated name. Listeners that share a class
     * and method name map to the same method, which is emitted only once.
     */
    private static void addListenerMethods(ClassNode classNode, Set<DirectListener> directListeners) {
        Set<String> generatedNames = new HashSet<String>();
        for (DirectListener directListener : directListeners) {
            if (!generatedNames.add(getDynamicListenerMethodName(directListener))) {
                continue;
            }
            boolean isVoid = directListener.getListenerMethod().getReturnType() == void.class;
            classNode.methods.add(isVoid
                    ? createVoidListenerMethod(directListener)
                    : createMutableListenerCallbackMethod(directListener));
        }
    }

    /**
     * Returns the name of the method generated in Callback for the given listener:
     * the listener's class name with '.' replaced by '_', followed by '_' and the listener
     * method name.
     */
    public static String getDynamicListenerMethodName(DirectListener directListener) {
        return directListener.getClass().getName().replace('.', '_') + "_" + directListener.getListenerMethod().getName();
    }

    /** Creates an empty public static synthetic method named after the listener; its descriptor is set later. */
    private static MethodNode newListenerMethodNode(DirectListener directListener) {
        return new MethodNode(ACC_PUBLIC | ACC_STATIC | ACC_SYNTHETIC,
                getDynamicListenerMethodName(directListener), null, null, null);
    }

    /**
     * Builds the generated method for a listener whose method returns void.
     * Descriptor: {@code (I <listener params>)V}.
     */
    private static MethodNode createVoidListenerMethod(DirectListener directListener) {
        Method listenerMethod = directListener.getListenerMethod();
        MethodNode methodNode = newListenerMethodNode(directListener);
        updateVoidDescriptor(methodNode, listenerMethod);
        InsnList insnList = methodNode.instructions;
        addNotificationGuards(insnList, Type.VOID_TYPE);

        LabelNode tryStart = new LabelNode();
        LabelNode tryEnd = new LabelNode();
        LabelNode handler = new LabelNode();
        methodNode.tryCatchBlocks.add(new TryCatchBlockNode(tryStart, tryEnd, handler, null));
        insnList.add(tryStart);
        // Listener parameters start at slot 1, right after the int listener index.
        addListenerInvocation(insnList, listenerMethod, 1);
        addSetNotifyingFlagInstructions(insnList, "FALSE");
        insnList.add(new InsnNode(RETURN));
        insnList.add(tryEnd);

        addCatchAllHandlerPrologue(insnList, handler);
        addSetNotifyingFlagInstructions(insnList, "FALSE");
        insnList.add(new InsnNode(RETURN));
        return methodNode;
    }

    /**
     * Builds the generated method for a listener whose method returns a value.
     * Descriptor: {@code (I <R> <listener params>)R}, where the R argument is the incoming value.
     * That value is returned unchanged when notification is skipped or the listener throws;
     * otherwise the listener's result is returned.
     */
    private static MethodNode createMutableListenerCallbackMethod(DirectListener directListener) {
        Method listenerMethod = directListener.getListenerMethod();
        Type returnType = Type.getReturnType(listenerMethod);
        MethodNode methodNode = newListenerMethodNode(directListener);
        updateMutableDescriptor(methodNode, listenerMethod);
        InsnList insnList = methodNode.instructions;
        addNotificationGuards(insnList, returnType);

        // First local slot after all arguments; holds the value that will be returned.
        int resultSlot = computeInitialMaxLocals(methodNode.desc);
        // Default the result to the incoming value (slot 1) so the handler can return it.
        insnList.add(ASMUtils.getLoadInst(returnType, 1));
        insnList.add(ASMUtils.getStoreInst(returnType, resultSlot));

        LabelNode tryStart = new LabelNode();
        LabelNode tryEnd = new LabelNode();
        LabelNode handler = new LabelNode();
        methodNode.tryCatchBlocks.add(new TryCatchBlockNode(tryStart, tryEnd, handler, null));
        insnList.add(tryStart);
        // Parameters follow the int index (slot 0) and the incoming value, which takes two
        // slots when it is a long or double.
        addListenerInvocation(insnList, listenerMethod, 1 + returnType.getSize());
        insnList.add(ASMUtils.getStoreInst(returnType, resultSlot));
        addSetNotifyingFlagInstructions(insnList, "FALSE");
        addReturn(insnList, returnType, resultSlot);
        insnList.add(tryEnd);

        // NOTE(review): the handler frame is F_SAME1 relative to the guard frame, which predates
        // resultSlot, so a StackMapTable (type-checking) verifier would treat resultSlot as
        // unset when it is loaded below. Confirm Callback is exempt from such verification, or
        // emit a full frame here.
        addCatchAllHandlerPrologue(insnList, handler);
        addSetNotifyingFlagInstructions(insnList, "FALSE");
        addReturn(insnList, returnType, resultSlot);
        return methodNode;
    }

    /**
     * Appends the two guards shared by every generated method. The method returns early unless
     * {@code CallbackEnabler.isThreadNotificationEnabled()} is true, and also returns early if
     * {@code Callback.NOTIFYING_FLAG.get()} is {@code Boolean.TRUE}.
     */
    private static void addNotificationGuards(InsnList insnList, Type returnType) {
        LabelNode notificationEnabled = new LabelNode();
        insnList.add(new MethodInsnNode(INVOKESTATIC, CALLBACK_ENABLER_JVM_CLASS_NAME, "isThreadNotificationEnabled", "()Z", false));
        insnList.add(new JumpInsnNode(IFNE, notificationEnabled));
        addReturn(insnList, returnType, 1);
        insnList.add(notificationEnabled);
        insnList.add(new FrameNode(F_SAME, 0, null, 0, null));

        LabelNode notNotifying = new LabelNode();
        insnList.add(new FieldInsnNode(GETSTATIC, "java/lang/Boolean", "TRUE", "Ljava/lang/Boolean;"));
        insnList.add(new FieldInsnNode(GETSTATIC, CALLBACK_JVM_CLASS_NAME, "NOTIFYING_FLAG", "Ljava/lang/ThreadLocal;"));
        insnList.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/ThreadLocal", "get", "()Ljava/lang/Object;", false));
        // Reference comparison: continue only if the flag is not the Boolean.TRUE instance.
        insnList.add(new JumpInsnNode(IF_ACMPNE, notNotifying));
        addReturn(insnList, returnType, 1);
        insnList.add(notNotifying);
        insnList.add(new FrameNode(F_SAME, 0, null, 0, null));
    }

    /**
     * Appends the listener call. It sets NOTIFYING_FLAG to TRUE, then invokes
     * {@code listenerMethod} on {@code Callback.listeners[arg0]} cast to the listener interface.
     * Parameters are loaded from consecutive local slots starting at {@code firstParamSlot}.
     * The listener's return value, if any, is left on the operand stack.
     */
    private static void addListenerInvocation(InsnList insnList, Method listenerMethod, int firstParamSlot) {
        addSetNotifyingFlagInstructions(insnList, "TRUE");
        insnList.add(new FieldInsnNode(GETSTATIC, CALLBACK_JVM_CLASS_NAME, "listeners", "[Ljava/lang/Object;"));
        insnList.add(new VarInsnNode(ILOAD, 0));
        insnList.add(new InsnNode(AALOAD));
        String listenerInterface = Utils.getJvmInterfaceNameForDirectListener(listenerMethod.getDeclaringClass().getName());
        insnList.add(new TypeInsnNode(CHECKCAST, listenerInterface));
        int slot = firstParamSlot;
        for (Type parameterType : Type.getArgumentTypes(listenerMethod)) {
            insnList.add(ASMUtils.getLoadInst(parameterType, slot));
            // long and double occupy two local-variable slots.
            slot += parameterType.getSize();
        }
        insnList.add(new MethodInsnNode(INVOKEINTERFACE, listenerInterface, listenerMethod.getName(),
                Type.getMethodDescriptor(listenerMethod), true));
    }

    /** Appends the catch-all handler entry: its label, its frame, and a call to Callback.handleThrowable. */
    private static void addCatchAllHandlerPrologue(InsnList insnList, LabelNode handler) {
        insnList.add(handler);
        insnList.add(new FrameNode(F_SAME1, 0, null, 1, new Object[]{"java/lang/Throwable"}));
        insnList.add(new MethodInsnNode(INVOKESTATIC, CALLBACK_JVM_CLASS_NAME, "handleThrowable", "(Ljava/lang/Throwable;)V", false));
    }

    /**
     * Appends a return instruction. Void methods get a plain RETURN; otherwise the local at
     * {@code valueSlot} is loaded and returned.
     */
    private static void addReturn(InsnList insnList, Type returnType, int valueSlot) {
        if (returnType.getSort() == Type.VOID) {
            insnList.add(new InsnNode(RETURN));
        } else {
            insnList.add(ASMUtils.getLoadInst(returnType, valueSlot));
            insnList.add(ASMUtils.getReturnInst(returnType));
        }
    }

    /** Appends {@code Callback.NOTIFYING_FLAG.set(Boolean.<str>)}; {@code str} is "TRUE" or "FALSE". */
    private static void addSetNotifyingFlagInstructions(InsnList insnList, String str) {
        insnList.add(new FieldInsnNode(GETSTATIC, CALLBACK_JVM_CLASS_NAME, "NOTIFYING_FLAG", "Ljava/lang/ThreadLocal;"));
        insnList.add(new FieldInsnNode(GETSTATIC, "java/lang/Boolean", str, "Ljava/lang/Boolean;"));
        insnList.add(new MethodInsnNode(INVOKEVIRTUAL, "java/lang/ThreadLocal", "set", "(Ljava/lang/Object;)V", false));
    }

    /**
     * Returns the total slot size of the arguments in method descriptor {@code str}. For a static
     * method, this is the index of the first local variable slot after the arguments.
     */
    private static int computeInitialMaxLocals(String str) {
        int size = 0;
        for (Type type : Type.getArgumentTypes(str)) {
            size += type.getSize();
        }
        return size;
    }

    /** Returns true once the Callback class has been processed successfully (or needed no change). */
    public boolean isCompleted() {
        return this.completed;
    }

    /** Descriptor of the generated method for a void listener: {@code (I <params>)V}. */
    public static String getDynamicListenerVoidMethodDescriptor(DirectListener directListener) {
        return updateVoidDescriptor(null, directListener.getListenerMethod());
    }

    /** Descriptor of the generated method for a value-returning listener: {@code (I <R> <params>)R}. */
    public static String getDynamicListenerMutatorMethodDescriptor(DirectListener directListener) {
        return updateMutableDescriptor(null, directListener.getListenerMethod());
    }

    private static String updateVoidDescriptor(MethodNode methodNode, Method method) {
        return updateDescriptor(methodNode, method, false);
    }

    private static String updateMutableDescriptor(MethodNode methodNode, Method method) {
        return updateDescriptor(methodNode, method, true);
    }

    /**
     * Builds the generated-method descriptor for {@code method} and returns it.
     *
     * @param methodNode if non-null, receives the descriptor and one local-variable entry
     *     ("arg1", "arg2", ...) per listener parameter, at that parameter's actual slot
     * @param passesReturnValue if true, the listener's return type is inserted after the int
     *     index argument and used as the return type; otherwise the method returns void
     */
    private static String updateDescriptor(MethodNode methodNode, Method method, boolean passesReturnValue) {
        Type returnType = Type.getReturnType(method);
        StringBuilder descriptor = new StringBuilder("(I");
        int slot = 1; // slot 0 holds the int listener index
        if (passesReturnValue) {
            descriptor.append(returnType.getDescriptor());
            slot += returnType.getSize();
        }
        Class<?>[] parameterTypes = method.getParameterTypes();
        for (int i = 0; i < parameterTypes.length; i++) {
            Type parameterType = Type.getType(parameterTypes[i]);
            String parameterDescriptor = parameterType.getDescriptor();
            descriptor.append(parameterDescriptor);
            if (methodNode != null) {
                methodNode.localVariables.add(new LocalVariableNode("arg" + (i + 1), parameterDescriptor, null,
                        new LabelNode(), new LabelNode(), slot));
            }
            slot += parameterType.getSize();
        }
        descriptor.append(')').append(passesReturnValue ? returnType.getDescriptor() : "V");
        String result = descriptor.toString();
        if (methodNode != null) {
            methodNode.desc = result;
        }
        return result;
    }
}
