/*
 * Decompiled with CFR 0.152.
 */
package qilin.pta.toolkits.selectx;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import qilin.core.PTA;
import qilin.core.builder.MethodNodeFactory;
import qilin.core.pag.AllocNode;
import qilin.core.pag.CallSite;
import qilin.core.pag.FieldRefNode;
import qilin.core.pag.GlobalVarNode;
import qilin.core.pag.LocalVarNode;
import qilin.core.pag.MethodPAG;
import qilin.core.pag.Node;
import qilin.core.pag.PAG;
import qilin.core.pag.Parm;
import qilin.pta.toolkits.selectx.BNode;
import qilin.pta.toolkits.selectx.G;
import qilin.pta.toolkits.selectx.I;
import qilin.pta.toolkits.selectx.L;
import qilin.pta.toolkits.selectx.O;
import soot.RefLikeType;
import soot.SootMethod;
import soot.Unit;
import soot.Value;
import soot.jimple.AssignStmt;
import soot.jimple.InstanceInvokeExpr;
import soot.jimple.InvokeExpr;
import soot.jimple.NullConstant;
import soot.jimple.Stmt;
import soot.jimple.spark.pag.SparkField;
import soot.jimple.toolkits.callgraph.Edge;
import soot.util.queue.QueueReader;

/**
 * Builds a graph of {@link BNode} wrappers over a pre-analysis PAG and propagates "visited" and
 * "parameter" ({@code paras}) facts across it. {@link #process()} reports a 0/1 flag for
 * allocation nodes, local variables and fields.
 *
 * <p>Local and global PAG nodes are wrapped twice, as {@code X.v(node, true)} and
 * {@code X.v(node, false)}; allocation nodes are wrapped once via {@code O.v(node)}. Most of the
 * edge-adding methods below add an edge between the {@code true} variants and a reversed edge
 * between the {@code false} variants. NOTE(review): the {@code false} variant looks like the
 * inverse of the {@code true} one (cf. {@code l.inv()} in {@link #process()}); confirm against
 * {@code L}/{@code G}/{@code O}.
 */
public class Selectx {
    private final PTA prePTA;
    private final PAG prePAG;
    /** Every field seen in a load or store while building the graph; all are reported as 1. */
    private final Set<SparkField> sparkFields = new HashSet<SparkField>();
    /** Numbering of call sites, starting at 1, used to label entry/exit edges. */
    Map<CallSite, Integer> call2Number = new HashMap<CallSite, Integer>();
    /** Number of distinct call sites numbered so far (also the last number handed out). */
    int totalCallsites = 0;

    /**
     * Creates the analysis and immediately builds its graph.
     *
     * @param pta the pre-analysis whose PAG, reachable methods and call graph are read
     */
    public Selectx(PTA pta) {
        this.prePTA = pta;
        this.prePAG = pta.getPag();
        this.buildGraph();
    }

    /**
     * Records allocation {@code from} flowing into {@code to}: {@code O(from) -> L(to, true)} and
     * {@code L(to, false) -> O(from)}.
     */
    public void addNewEdge(AllocNode from, LocalVarNode to) {
        O fromE = O.v(from);
        L toE = L.v(to, true);
        fromE.addOutEdge(toE);
        L toEI = L.v(to, false);
        toEI.addOutEdge(fromE);
    }

    /**
     * Records a local-to-local copy: {@code from -> to} between the {@code true} variants and
     * {@code to -> from} between the {@code false} variants.
     */
    public void addAssignEdge(LocalVarNode from, LocalVarNode to) {
        L fromE = L.v(from, true);
        L toE = L.v(to, true);
        fromE.addOutEdge(toE);
        L fromEI = L.v(from, false);
        L toEI = L.v(to, false);
        toEI.addOutEdge(fromEI);
    }

    /**
     * Records an actual-to-formal edge labelled with {@code callSite}'s number. The matching
     * in-entry edge and the reversed exit edge on the {@code false} variants are only added
     * when the forward entry edge is new.
     */
    public void addEntryEdge(LocalVarNode from, LocalVarNode to, CallSite callSite) {
        int i = this.getCallSiteNumber(callSite);
        L fromE = L.v(from, true);
        L toE = L.v(to, true);
        if (fromE.addOutEntryEdge(i, toE)) {
            toE.addInEntryEdge(i, fromE);
            L fromEI = L.v(from, false);
            L toEI = L.v(to, false);
            toEI.addOutExitEdge(i, fromEI);
        }
    }

    /**
     * Records a callee-to-caller return edge labelled with {@code callSite}'s number. When the
     * forward exit edge is new, the {@code false} variants get a reversed entry edge (with its
     * in-entry counterpart).
     */
    public void addExitEdge(LocalVarNode from, LocalVarNode to, CallSite callSite) {
        int i = this.getCallSiteNumber(callSite);
        L fromE = L.v(from, true);
        L toE = L.v(to, true);
        if (fromE.addOutExitEdge(i, toE)) {
            L fromEI = L.v(from, false);
            L toEI = L.v(to, false);
            toEI.addOutEntryEdge(i, fromEI);
            fromEI.addInEntryEdge(i, toEI);
        }
    }

    /**
     * Records a store {@code base.f = from}. Edges cross between variants:
     * {@code from(true) -> base(false)} and {@code base(true) -> from(false)}.
     */
    public void addStoreEdge(LocalVarNode from, LocalVarNode base) {
        L fromE = L.v(from, true);
        L baseE = L.v(base, true);
        L fromEI = L.v(from, false);
        L baseEI = L.v(base, false);
        fromE.addOutEdge(baseEI);
        baseE.addOutEdge(fromEI);
    }

    /** Records a static-field store: {@code L(from) -> G(to)} and {@code G(to)' -> L(from)'}. */
    public void addStaticStoreEdge(LocalVarNode from, GlobalVarNode to) {
        L fromE = L.v(from, true);
        G toE = G.v(to, true);
        fromE.addOutEdge(toE);
        L fromEI = L.v(from, false);
        G toEI = G.v(to, false);
        toEI.addOutEdge(fromEI);
    }

    /** Records a static-field load: {@code G(from) -> L(to)} and {@code L(to)' -> G(from)'}. */
    public void addStaticLoadEdge(GlobalVarNode from, LocalVarNode to) {
        G fromE = G.v(from, true);
        L toE = L.v(to, true);
        fromE.addOutEdge(toE);
        G fromEI = G.v(from, false);
        L toEI = L.v(to, false);
        toEI.addOutEdge(fromEI);
    }

    /**
     * Drains both worklists to a fixed point.
     *
     * <p>Phase 1 drains {@code workList}. It enqueues each forward target for which
     * {@code setVisited()} returns {@code true}. For an {@link L} node it also enqueues, on
     * {@code paraWorkList}, each out-entry-edge target that adds itself to its own {@code paras}.
     *
     * <p>Phase 2 drains {@code paraWorkList}. It enqueues each out-target whose
     * {@code update(node)} returns {@code true}. For an {@link L} node it also runs
     * {@link #propagateFromLocal}. Phase 2 can refill {@code workList}, so the outer loop repeats
     * until both are empty.
     *
     * @param workList nodes whose visited state must be pushed forward; drained in place
     * @param paraWorkList nodes whose {@code paras} changed; drained in place
     */
    private void propagate(Set<BNode> workList, Set<I> paraWorkList) {
        while (!workList.isEmpty() || !paraWorkList.isEmpty()) {
            while (!workList.isEmpty()) {
                BNode node = workList.iterator().next();
                workList.remove(node);
                node.forwardTargets().filter(BNode::setVisited).forEach(workList::add);
                if (node instanceof L) {
                    ((L) node).getOutEntryEdges().stream()
                            .filter(tgt -> tgt.paras.add(tgt))
                            .forEach(paraWorkList::add);
                }
            }
            while (!paraWorkList.isEmpty()) {
                // Loop-scoped so the lambda below captures an effectively final variable.
                I node = paraWorkList.iterator().next();
                paraWorkList.remove(node);
                node.getOutTargets().stream()
                        .filter(target -> target.update(node))
                        .forEach(paraWorkList::add);
                if (node instanceof L) {
                    propagateFromLocal((L) node, workList, paraWorkList);
                }
            }
        }
    }

    /**
     * Phase-2 work specific to an {@link L} node.
     *
     * <p>Enqueues the out-G nodes that {@code setVisited()} accepts. Re-offers the out-entry-edge
     * targets as in phase 1. Then, for each labelled out-exit edge of {@code l}, pairs every
     * in-entry edge with the same label on the nodes in {@code l.paras} with every exit target.
     * Each newly added {@code arg -> tgt} edge pushes visited state and parameters to {@code tgt}.
     */
    private static void propagateFromLocal(L l, Set<BNode> workList, Set<I> paraWorkList) {
        l.getOutGs().filter(BNode::setVisited).forEach(workList::add);
        l.getOutEntryEdges().stream().filter(tgt -> tgt.paras.add(tgt)).forEach(paraWorkList::add);
        for (Map.Entry<Integer, Set<L>> entry : l.getOutExitEdges()) {
            Integer i = entry.getKey();
            Set<L> tgts = entry.getValue();
            l.paras.stream()
                    .flatMap(para -> para.getInEntryEdges(i).stream())
                    .forEach(arg -> tgts.forEach(tgt -> {
                        // Only a newly inserted edge needs to push facts across.
                        if (arg.addOutEdge((I) tgt)) {
                            if (arg.isVisited() && tgt.setVisited()) {
                                workList.add(tgt);
                            }
                            if (tgt.update((I) arg)) {
                                paraWorkList.add(tgt);
                            }
                        }
                    }));
        }
    }

    /** Resets every G, L and O node and clears the parameter sets of every L and O node. */
    private void resetNodes() {
        G.g2GN.values().forEach(BNode::reset);
        G.g2GP.values().forEach(BNode::reset);
        L.l2LN.values().forEach(BNode::reset);
        L.l2LP.values().forEach(BNode::reset);
        O.o2O.values().forEach(BNode::reset);
        L.l2LN.values().forEach(I::clearParas);
        L.l2LP.values().forEach(I::clearParas);
        O.o2O.values().forEach(I::clearParas);
    }

    /**
     * Runs two propagation passes and returns a 0/1 flag per reported node or field.
     *
     * <p>Pass 1 seeds every {@link O} node. The {@link O} and {@link L} nodes that end it with a
     * non-empty {@code paras} are remembered. After {@link #resetNodes()}, pass 2 seeds every node
     * in {@code L.l2LN}. The result maps:
     * <ul>
     *   <li>each remembered {@code O}'s {@code sparkNode} to 1 iff its {@code paras} is
     *       non-empty after pass 2, else 0;</li>
     *   <li>each remembered {@code L}'s {@code sparkNode} to 1 iff {@code l.inv().paras} is
     *       non-empty after pass 2, else 0;</li>
     *   <li>every field seen in a load/store to 1.</li>
     * </ul>
     * NOTE(review): an {@code l2LP} node and an {@code l2LN} node may wrap the same
     * {@code sparkNode}. If both are remembered, the later {@code put} wins, and which one is
     * later is decided by HashSet iteration order. Confirm the two cannot disagree.
     *
     * @return a new map from {@code sparkNode}s / {@link SparkField}s to 0 or 1
     */
    public Map<Object, Integer> process() {
        System.out.print("cs2 propogating ...");
        long time = System.currentTimeMillis();
        Set<BNode> workList = new HashSet<BNode>();
        Set<I> paraWorkList = new HashSet<I>();

        // Pass 1: start from every allocation node.
        O.o2O.values().forEach(o -> {
            o.setVisited();
            workList.add(o);
        });
        this.propagate(workList, paraWorkList);
        Set<O> entryO = O.o2O.values().stream()
                .filter(o -> !o.paras.isEmpty())
                .collect(Collectors.toSet());
        Set<L> entryL = Stream.concat(L.l2LP.values().stream(), L.l2LN.values().stream())
                .filter(l -> !l.paras.isEmpty())
                .collect(Collectors.toSet());

        // Pass 2: start over from every node in L.l2LN.
        this.resetNodes();
        L.l2LN.values().forEach(ln -> {
            ln.setVisited();
            workList.add(ln);
        });
        this.propagate(workList, paraWorkList);
        System.out.println((System.currentTimeMillis() - time) / 1000L + "s");

        Map<Object, Integer> ret = new HashMap<Object, Integer>();
        for (O o : entryO) {
            ret.put(o.sparkNode, o.paras.isEmpty() ? 0 : 1);
        }
        for (L l : entryL) {
            ret.put(l.sparkNode, l.inv().paras.isEmpty() ? 0 : 1);
        }
        for (SparkField f : this.sparkFields) {
            ret.put(f, 1);
        }
        return ret;
    }

    /**
     * Returns the number of {@code callsite}, assigning the next one (starting at 1) the first
     * time a call site is seen.
     */
    int getCallSiteNumber(CallSite callsite) {
        Integer oldNumber = this.call2Number.get(callsite);
        if (oldNumber != null) {
            return oldNumber;
        }
        ++this.totalCallsites;
        this.call2Number.put(callsite, this.totalCallsites);
        return this.totalCallsites;
    }

    /** Adds intra- and inter-procedural edges for every reachable, non-phantom method. */
    private void buildGraph() {
        for (SootMethod method : this.prePTA.getNakedReachableMethods()) {
            if (method.isPhantom()) {
                continue;
            }
            MethodPAG srcmpag = this.prePAG.getMethodPAG(method);
            this.addIntraproceduralEdges(srcmpag);
            MethodNodeFactory srcnf = srcmpag.nodeFactory();
            for (Unit u : srcmpag.getInvokeStmts()) {
                this.addCallSiteEdges(method, srcnf, u);
            }
        }
    }

    /** Adds this method's internal PAG edges and its exception edges to the graph. */
    private void addIntraproceduralEdges(MethodPAG srcmpag) {
        // Read from a clone, presumably so the shared internal reader's position is untouched.
        QueueReader<?> reader = (QueueReader<?>) srcmpag.getInternalReader().clone();
        while (reader.hasNext()) {
            // Edges are stored as consecutive (from, to) pairs.
            Node from = (Node) reader.next();
            Node to = (Node) reader.next();
            if (from instanceof LocalVarNode) {
                LocalVarNode src = (LocalVarNode) from;
                if (to instanceof LocalVarNode) {
                    this.addAssignEdge(src, (LocalVarNode) to);
                } else if (to instanceof FieldRefNode) {
                    FieldRefNode fr = (FieldRefNode) to;
                    this.addStoreEdge(src, (LocalVarNode) fr.getBase());
                    this.sparkFields.add(fr.getField());
                } else {
                    assert (to instanceof GlobalVarNode);
                    this.addStaticStoreEdge(src, (GlobalVarNode) to);
                }
            } else if (from instanceof AllocNode) {
                if (to instanceof LocalVarNode) {
                    this.addNewEdge((AllocNode) from, (LocalVarNode) to);
                }
            } else if (from instanceof FieldRefNode) {
                // Loads are modelled as a copy from the base variable.
                FieldRefNode fr = (FieldRefNode) from;
                this.addAssignEdge((LocalVarNode) fr.getBase(), (LocalVarNode) to);
                this.sparkFields.add(fr.getField());
            } else {
                assert (from instanceof GlobalVarNode);
                this.addStaticLoadEdge((GlobalVarNode) from, (LocalVarNode) to);
            }
        }
        srcmpag.getExceptionEdges().forEach((k, vs) -> {
            for (Node v : vs) {
                this.addAssignEdge((LocalVarNode) k, (LocalVarNode) v);
            }
        });
    }

    /**
     * Adds entry/exit edges for invoke statement {@code u} of {@code method}, once per
     * call-graph edge out of {@code u}.
     */
    private void addCallSiteEdges(SootMethod method, MethodNodeFactory srcnf, Unit u) {
        Stmt stmt = (Stmt) u;
        CallSite callSite = new CallSite(u);
        InvokeExpr ie = stmt.getInvokeExpr();
        int numArgs = ie.getArgCount();

        // Keep only reference-typed, non-null-constant actuals; other slots stay null and are
        // skipped below.
        Value[] args = new Value[numArgs];
        for (int i = 0; i < numArgs; ++i) {
            Value arg = ie.getArg(i);
            if (arg.getType() instanceof RefLikeType && !(arg instanceof NullConstant)) {
                args[i] = arg;
            }
        }

        LocalVarNode retDest = null;
        if (stmt instanceof AssignStmt) {
            Value dest = ((AssignStmt) stmt).getLeftOp();
            if (dest.getType() instanceof RefLikeType) {
                retDest = this.prePAG.findLocalVarNode(dest);
            }
        }

        LocalVarNode receiver = null;
        if (ie instanceof InstanceInvokeExpr) {
            receiver = this.prePAG.findLocalVarNode(((InstanceInvokeExpr) ie).getBase());
        }

        Iterator<Edge> it = this.prePTA.getCallGraph().edgesOutOf(u);
        while (it.hasNext()) {
            SootMethod tgtmtd = it.next().tgt();
            MethodPAG tgtmpag = this.prePAG.getMethodPAG(tgtmtd);
            MethodNodeFactory tgtnf = tgtmpag.nodeFactory();
            for (int i = 0; i < numArgs; ++i) {
                if (args[i] == null || !(tgtmtd.getParameterType(i) instanceof RefLikeType)) {
                    continue;
                }
                LocalVarNode parm = (LocalVarNode) tgtnf.caseParm(i);
                this.addEntryEdge((LocalVarNode) srcnf.getNode(args[i]), parm, callSite);
            }
            if (retDest != null && tgtmtd.getReturnType() instanceof RefLikeType) {
                LocalVarNode ret = (LocalVarNode) tgtnf.caseRet();
                this.addExitEdge(ret, retDest, callSite);
            }
            LocalVarNode stmtThrowNode = srcnf.makeInvokeStmtThrowVarNode(stmt, method);
            // NOTE(review): Parm index -3 is used here as the callee's throw node
            // ("throwFinal"); confirm the constant against Parm's definition.
            LocalVarNode throwFinal = this.prePAG.findLocalVarNode(new Parm(tgtmtd, -3));
            if (throwFinal != null) {
                this.addExitEdge(throwFinal, stmtThrowNode, callSite);
            }
            if (receiver != null) {
                LocalVarNode thisRef = (LocalVarNode) tgtnf.caseThis();
                this.addEntryEdge(receiver, thisRef, callSite);
            }
        }
    }
}

