/*
 * Decompiled with CFR 0.152.
 */
package qilin.core.reflection;

import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import qilin.CoreConfig;
import qilin.core.PTAScene;
import qilin.core.reflection.ReflectionKind;
import qilin.core.reflection.ReflectionModel;
import qilin.util.DataFactory;
import qilin.util.PTAUtils;
import soot.ArrayType;
import soot.Body;
import soot.Local;
import soot.RefLikeType;
import soot.RefType;
import soot.SootClass;
import soot.SootField;
import soot.SootMethod;
import soot.Unit;
import soot.Value;
import soot.jimple.AssignStmt;
import soot.jimple.ClassConstant;
import soot.jimple.ConvertToBaf;
import soot.jimple.IntConstant;
import soot.jimple.InvokeExpr;
import soot.jimple.Jimple;
import soot.jimple.NullConstant;
import soot.jimple.Stmt;
import soot.jimple.internal.AbstractInvokeExpr;
import soot.jimple.internal.JArrayRef;
import soot.jimple.internal.JAssignStmt;
import soot.jimple.internal.JInstanceFieldRef;
import soot.jimple.internal.JInvokeStmt;
import soot.jimple.internal.JNewArrayExpr;
import soot.jimple.internal.JNewExpr;
import soot.jimple.internal.JSpecialInvokeExpr;
import soot.jimple.internal.JStaticInvokeExpr;
import soot.jimple.internal.JVirtualInvokeExpr;
import soot.jimple.internal.JimpleLocal;
import soot.tagkit.LineNumberTag;

/**
 * A {@link ReflectionModel} driven by a TamiFlex reflection log.
 *
 * <p>The log named by {@code CoreConfig.v().getAppConfig().REFLECTION_LOG} is parsed once, at
 * construction time, into {@link #reflectionMap}. Each log entry associates a reflective call site
 * (located by its enclosing method and, when present, its source line) with a class, method, field
 * or array type. The {@code transform*} methods turn such a call site into ordinary Jimple
 * statements that spell out the logged targets explicitly.
 */
public class TamiflexModel extends ReflectionModel {
    /** Logged reflective targets, keyed first by reflection kind and then by call-site statement. */
    protected Map<ReflectionKind, Map<Stmt, Set<String>>> reflectionMap = DataFactory.createMap();

    /** Creates the model and eagerly parses the configured TamiFlex log with warnings enabled. */
    public TamiflexModel() {
        this.parseTamiflexLog(CoreConfig.v().getAppConfig().REFLECTION_LOG, true);
    }

    /**
     * Models {@code Class.forName}: assigns a class constant for every logged class name to the
     * call's result variable. A call whose result is discarded yields no statements.
     */
    @Override
    Collection<Unit> transformClassForName(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        if (!(s2 instanceof AssignStmt assign)) {
            return ret;
        }
        Value lvalue = assign.getLeftOp();
        for (String clazz : this.targetsOf(ReflectionKind.ClassForName, s2)) {
            ret.add(new JAssignStmt(lvalue, ClassConstant.fromType(RefType.v(clazz))));
        }
        return ret;
    }

    /**
     * Models {@code Class.newInstance()}: for every logged class that declares {@code void <init>()},
     * emits an allocation into the result variable followed by a call to that constructor.
     */
    @Override
    protected Collection<Unit> transformClassNewInstance(Stmt s2) {
        if (!(s2 instanceof AssignStmt assign)) {
            return Collections.emptySet();
        }
        Value lvalue = assign.getLeftOp();
        Set<Unit> ret = DataFactory.createSet();
        var initSubSig = PTAScene.v().getSubSigNumberer().findOrAdd("void <init>()");
        for (String clsName : this.targetsOf(ReflectionKind.ClassNewInstance, s2)) {
            SootClass cls = PTAScene.v().getSootClass(clsName);
            if (!cls.declaresMethod(initSubSig)) {
                continue;
            }
            SootMethod constructor = cls.getMethod(initSubSig);
            ret.add(new JAssignStmt(lvalue, new JNewExpr(cls.getType())));
            ret.add(new JInvokeStmt(
                    new JSpecialInvokeExpr((Local) lvalue, constructor.makeRef(), Collections.emptyList())));
        }
        return ret;
    }

    /**
     * Models {@code Constructor.newInstance(Object[])}: for every logged constructor, emits an
     * allocation of its declaring class into the result variable and a call to that constructor.
     * Every constructor parameter receives the same value read from the argument array.
     */
    @Override
    protected Collection<Unit> transformContructorNewInstance(Stmt s2) {
        if (!(s2 instanceof AssignStmt assign)) {
            return Collections.emptySet();
        }
        Set<Unit> ret = DataFactory.createSet();
        Set<String> constructorSignatures = this.targetsOf(ReflectionKind.ConstructorNewInstance, s2);
        if (constructorSignatures.isEmpty()) {
            return ret;
        }
        Value lvalue = assign.getLeftOp();
        Value arg = loadArgumentElement(s2.getInvokeExpr().getArg(0), ret);
        for (String constructorSignature : constructorSignatures) {
            SootMethod constructor = PTAScene.v().getMethod(constructorSignature);
            List<Value> mArgs = Collections.nCopies(constructor.getParameterCount(), arg);
            ret.add(new JAssignStmt(lvalue, new JNewExpr(constructor.getDeclaringClass().getType())));
            ret.add(new JInvokeStmt(new JSpecialInvokeExpr((Local) lvalue, constructor.makeRef(), mArgs)));
        }
        return ret;
    }

    /**
     * Models {@code Method.invoke(Object, Object[])}: emits a direct call to every logged method.
     * The call is assigned to the original result variable when there is one. Every parameter
     * receives the same value read from the argument array.
     */
    @Override
    protected Collection<Unit> transformMethodInvoke(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        Set<String> methodSignatures = this.targetsOf(ReflectionKind.MethodInvoke, s2);
        if (methodSignatures.isEmpty()) {
            return ret;
        }
        InvokeExpr iie = s2.getInvokeExpr();
        Value base = iie.getArg(0);
        Value arg = loadArgumentElement(iie.getArg(1), ret);
        Value lvalue = s2 instanceof AssignStmt assign ? assign.getLeftOp() : null;
        for (String methodSignature : methodSignatures) {
            SootMethod method = PTAScene.v().getMethod(methodSignature);
            List<Value> mArgs = Collections.nCopies(method.getParameterCount(), arg);
            InvokeExpr ie;
            if (method.isStatic()) {
                // Method.invoke ignores the receiver of a static method, so it need not be null.
                ie = new JStaticInvokeExpr(method.makeRef(), mArgs);
            } else {
                assert (!(base instanceof NullConstant));
                if (method.getDeclaringClass().isInterface()) {
                    // Soot rejects virtual invoke expressions whose declaring class is an interface.
                    ie = Jimple.v().newInterfaceInvokeExpr((Local) base, method.makeRef(), mArgs);
                } else {
                    ie = new JVirtualInvokeExpr(base, method.makeRef(), mArgs);
                }
            }
            ret.add(lvalue != null ? new JAssignStmt(lvalue, ie) : new JInvokeStmt(ie));
        }
        return ret;
    }

    /**
     * Models {@code Field.set(Object, Object)}: emits a direct store of the value argument into
     * every logged reference-typed field. Primitive fields are skipped, matching
     * {@link #transformFieldGet}.
     */
    @Override
    protected Collection<Unit> transformFieldSet(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        Set<String> fieldSignatures = this.targetsOf(ReflectionKind.FieldSet, s2);
        if (fieldSignatures.isEmpty()) {
            return ret;
        }
        InvokeExpr iie = s2.getInvokeExpr();
        Value base = iie.getArg(0);
        Value rValue = iie.getArg(1);
        for (String fieldSignature : fieldSignatures) {
            SootField field = resolveField(fieldSignature);
            if (!(field.getType() instanceof RefLikeType)) {
                continue;
            }
            ret.add(new JAssignStmt(fieldRefOf(field, base), rValue));
        }
        return ret;
    }

    /**
     * Models {@code Field.get(Object)}: emits a direct load of every logged reference-typed field
     * into the call's result variable. Primitive fields are skipped.
     */
    @Override
    protected Collection<Unit> transformFieldGet(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        if (!(s2 instanceof AssignStmt assign)) {
            return ret;
        }
        Set<String> fieldSignatures = this.targetsOf(ReflectionKind.FieldGet, s2);
        if (fieldSignatures.isEmpty()) {
            return ret;
        }
        Value lvalue = assign.getLeftOp();
        Value base = s2.getInvokeExpr().getArg(0);
        for (String fieldSignature : fieldSignatures) {
            SootField field = resolveField(fieldSignature);
            if (!(field.getType() instanceof RefLikeType)) {
                continue;
            }
            ret.add(new JAssignStmt(lvalue, fieldRefOf(field, base)));
        }
        return ret;
    }

    /**
     * Models {@code Array.newInstance(Class, int)}: assigns a one-element array of every logged
     * array type to the call's result variable. Logged types that do not resolve to an array type
     * are ignored.
     */
    @Override
    protected Collection<Unit> transformArrayNewInstance(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        if (!(s2 instanceof AssignStmt assign)) {
            return ret;
        }
        Value lvalue = assign.getLeftOp();
        for (String arrayType : this.targetsOf(ReflectionKind.ArrayNewInstance, s2)) {
            if (!(PTAScene.v().getTypeUnsafe(arrayType, true) instanceof ArrayType at)) {
                continue;
            }
            ret.add(new JAssignStmt(lvalue, new JNewArrayExpr(at.getElementType(), IntConstant.v(1))));
        }
        return ret;
    }

    /**
     * Models {@code Array.get(Object, int)}: loads an element of the first argument into the call's
     * result variable. This does not depend on the log.
     */
    @Override
    Collection<Unit> transformArrayGet(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        InvokeExpr iie = s2.getInvokeExpr();
        if (s2 instanceof AssignStmt assign) {
            JArrayRef element = elementRefOf(iie.getArg(0), ret);
            if (element != null) {
                ret.add(new JAssignStmt(assign.getLeftOp(), element));
            }
        }
        return ret;
    }

    /**
     * Models {@code Array.set(Object, int, Object)}: stores the third argument into an element of
     * the first. This does not depend on the log. Array-typed and {@code Object}-typed bases are
     * handled the same way as in {@link #transformArrayGet}.
     */
    @Override
    Collection<Unit> transformArraySet(Stmt s2) {
        Set<Unit> ret = DataFactory.createSet();
        InvokeExpr iie = s2.getInvokeExpr();
        JArrayRef element = elementRefOf(iie.getArg(0), ret);
        if (element != null) {
            ret.add(new JAssignStmt(element, iie.getArg(2)));
        }
        return ret;
    }

    /** Returns the logged targets of {@code kind} for {@code stmt}, or an empty set if there are none. */
    private Set<String> targetsOf(ReflectionKind kind, Stmt stmt) {
        return this.reflectionMap.getOrDefault(kind, Collections.emptyMap())
                .getOrDefault(stmt, Collections.emptySet());
    }

    /**
     * Returns a reference to element 0 of {@code array}, used as a representative element. This
     * assumes the analysis does not distinguish array indices; confirm if that changes.
     *
     * <p>An array-typed value is indexed directly. A value statically typed {@code java.lang.Object}
     * is first copied into a fresh {@code Object[]} local, and that copy is appended to {@code out}.
     * For any other type (e.g. a {@code null} constant) there is nothing to index, and {@code null}
     * is returned.
     */
    private static JArrayRef elementRefOf(Value array, Collection<Unit> out) {
        if (array.getType() instanceof ArrayType) {
            return new JArrayRef(array, IntConstant.v(0));
        }
        if (RefType.v("java.lang.Object").equals(array.getType())) {
            JimpleLocal local = new JimpleLocal("intermediate/" + array, ArrayType.v(RefType.v("java.lang.Object"), 1));
            out.add(new JAssignStmt(local, array));
            return new JArrayRef(local, IntConstant.v(0));
        }
        return null;
    }

    /**
     * Produces the value passed for each parameter of a reflectively invoked method or constructor.
     * When {@code args} can be indexed, an element is copied into a fresh local and the copy is
     * appended to {@code out}. Otherwise (e.g. a {@code null} argument array) {@link NullConstant}
     * is returned, so that no Java {@code null} ever reaches an invoke expression.
     */
    private static Value loadArgumentElement(Value args, Collection<Unit> out) {
        JArrayRef element = elementRefOf(args, out);
        if (element == null) {
            return NullConstant.v();
        }
        JimpleLocal local = new JimpleLocal("intermediate/" + element, RefType.v("java.lang.Object"));
        out.add(new JAssignStmt(local, element));
        return local;
    }

    /** Resolves a field signature from the log to the field it denotes. */
    private static SootField resolveField(String fieldSignature) {
        return PTAScene.v().getField(fieldSignature).makeRef().resolve();
    }

    /**
     * Builds a static field reference, or an instance field reference on {@code base}. For a static
     * field the receiver is ignored, as {@code Field.get}/{@code Field.set} do, so it may be
     * non-null.
     */
    private static Value fieldRefOf(SootField field, Value base) {
        if (field.isStatic()) {
            return Jimple.v().newStaticFieldRef(field.makeRef());
        }
        assert (!(base instanceof NullConstant));
        return new JInstanceFieldRef(base, field.makeRef());
    }

    /**
     * Reads {@code logFile} line by line and records every valid entry in {@link #reflectionMap}.
     * The file is always closed.
     *
     * @throws RuntimeException wrapping the {@link IOException} if the log cannot be read
     */
    private void parseTamiflexLog(String logFile, boolean verbose) {
        try (BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream(logFile)))) {
            String line;
            while ((line = reader.readLine()) != null) {
                this.parseLogEntry(line, verbose);
            }
        } catch (IOException e) {
            throw new RuntimeException("Failed to read tamiflex log: " + logFile, e);
        }
    }

    /**
     * Parses one semicolon-separated log line:
     * {@code kind;mappedTarget;enclosingClass.method;lineNumber[;...]}. The target is recorded for
     * every call site it may refer to. Malformed, unsupported or unresolvable entries are skipped,
     * with a warning when {@code verbose} is set.
     */
    private void parseLogEntry(String line, boolean verbose) {
        String[] portions = line.split(";", -1);
        if (portions.length < 4) {
            warn(verbose, "Warning: illegal tamiflex log: " + line);
            return;
        }
        ReflectionKind kind = ReflectionKind.parse(portions[0]);
        if (kind == null) {
            warn(verbose, "Warning: illegal tamiflex reflection kind: " + portions[0]);
            return;
        }
        int lineNumber;
        try {
            // An empty line-number field means "unknown" and is encoded as -1.
            lineNumber = portions[3].isEmpty() ? -1 : Integer.parseInt(portions[3]);
        } catch (NumberFormatException e) {
            warn(verbose, "Warning: illegal tamiflex line number: " + line);
            return;
        }
        String mappedTarget = portions[1];
        if (!isResolvableTarget(kind, mappedTarget, verbose)) {
            return;
        }
        for (Stmt stmt : this.inferSourceStmt(portions[2], kind, lineNumber)) {
            this.reflectionMap.computeIfAbsent(kind, k -> DataFactory.createMap())
                    .computeIfAbsent(stmt, k -> DataFactory.createSet())
                    .add(mappedTarget);
        }
    }

    /**
     * Returns whether {@code mappedTarget} can be modelled for {@code kind}. The target must be
     * present in the scene, and the kind must be one this model handles. Otherwise a warning is
     * printed (if {@code verbose}) and {@code false} is returned.
     */
    private static boolean isResolvableTarget(ReflectionKind kind, String mappedTarget, boolean verbose) {
        return switch (kind) {
            case ClassForName, ArrayNewInstance -> true;
            case ClassNewInstance -> reportIfUnknown(PTAScene.v().containsClass(mappedTarget), verbose,
                    "Warning: Unknown mapped class for signature: " + mappedTarget);
            case ConstructorNewInstance, MethodInvoke -> reportIfUnknown(PTAScene.v().containsMethod(mappedTarget),
                    verbose, "Warning: Unknown mapped method for signature: " + mappedTarget);
            case FieldSet, FieldGet -> reportIfUnknown(PTAScene.v().containsField(mappedTarget), verbose,
                    "Warning: Unknown mapped field for signature: " + mappedTarget);
            default -> reportIfUnknown(false, verbose, "Warning: Unsupported reflection kind: " + kind);
        };
    }

    /** Prints {@code warning} when {@code known} is false (and {@code verbose}), and returns {@code known}. */
    private static boolean reportIfUnknown(boolean known, boolean verbose, String warning) {
        if (!known) {
            warn(verbose, warning);
        }
        return known;
    }

    /** Prints {@code message} to standard output when {@code verbose} is set. */
    private static void warn(boolean verbose, String message) {
        if (verbose) {
            System.out.println(message);
        }
    }

    /**
     * Returns the concrete methods named by {@code inClzDotMthd} ({@code fully.qualified.Class.method}).
     * Every overload with that name is included. A string without a {@code '.'} yields no methods.
     */
    private Collection<SootMethod> inferSourceMethod(String inClzDotMthd) {
        int lastDot = inClzDotMthd.lastIndexOf('.');
        if (lastDot < 0) {
            System.out.println("Warning: illegal source method \"" + inClzDotMthd + "\" is referenced.");
            return Collections.emptySet();
        }
        String inClassStr = inClzDotMthd.substring(0, lastDot);
        String inMethodStr = inClzDotMthd.substring(lastDot + 1);
        if (!PTAScene.v().containsClass(inClassStr)) {
            System.out.println("Warning: unknown class \"" + inClassStr + "\" is referenced.");
        }
        SootClass sootClass = PTAScene.v().getSootClass(inClassStr);
        Set<SootMethod> ret = DataFactory.createSet();
        for (SootMethod m : sootClass.getMethods()) {
            if (m.isConcrete() && m.getName().equals(inMethodStr)) {
                ret.add(m);
            }
        }
        return ret;
    }

    /**
     * Finds the call sites in the enclosing method(s) that match {@code kind}. When
     * {@code lineNumber >= 0}, candidates are narrowed to those tagged with that source line. If
     * narrowing leaves nothing, a warning is printed and all candidates are returned instead.
     */
    private Collection<Stmt> inferSourceStmt(String inClzDotMthd, ReflectionKind kind, int lineNumber) {
        Set<Stmt> potential = DataFactory.createSet();
        for (SootMethod sm : this.inferSourceMethod(inClzDotMthd)) {
            Body body = PTAUtils.getMethodBody(sm);
            for (Unit u : body.getUnits()) {
                if (u instanceof Stmt stmt && stmt.containsInvokeExpr()
                        && this.matchReflectionKind(kind, stmt.getInvokeExpr().getMethodRef().getSignature())) {
                    potential.add(stmt);
                }
            }
        }
        if (lineNumber < 0) {
            // No line information in the log: every matching call is a candidate.
            return potential;
        }
        Set<Stmt> ret = DataFactory.createSet();
        for (Stmt stmt : potential) {
            LineNumberTag tag = (LineNumberTag) stmt.getTag("LineNumberTag");
            if (tag != null && tag.getLineNumber() == lineNumber) {
                ret.add(stmt);
            }
        }
        if (ret.isEmpty() && !potential.isEmpty()) {
            System.out.print("Warning: Mismatch between statement and reflection log entry - ");
            System.out.println(kind + ";" + inClzDotMthd + ";" + lineNumber + ";");
            return potential;
        }
        return ret;
    }

    /** Returns whether {@code methodSig} is a JDK reflection API method corresponding to {@code kind}. */
    private boolean matchReflectionKind(ReflectionKind kind, String methodSig) {
        return switch (kind) {
            case ClassForName -> methodSig.equals("<java.lang.Class: java.lang.Class forName(java.lang.String)>")
                    || methodSig.equals("<java.lang.Class: java.lang.Class forName(java.lang.String,boolean,java.lang.ClassLoader)>");
            case ClassNewInstance -> methodSig.equals("<java.lang.Class: java.lang.Object newInstance()>");
            case ConstructorNewInstance -> methodSig.equals("<java.lang.reflect.Constructor: java.lang.Object newInstance(java.lang.Object[])>");
            case MethodInvoke -> methodSig.equals("<java.lang.reflect.Method: java.lang.Object invoke(java.lang.Object,java.lang.Object[])>");
            case FieldSet -> methodSig.equals("<java.lang.reflect.Field: void set(java.lang.Object,java.lang.Object)>");
            case FieldGet -> methodSig.equals("<java.lang.reflect.Field: java.lang.Object get(java.lang.Object)>");
            case ArrayNewInstance -> methodSig.equals("<java.lang.reflect.Array: java.lang.Object newInstance(java.lang.Class,int)>");
            default -> false;
        };
    }
}

