/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops.rewrite;

import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.UnaryOp;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.rewrite.HopRewriteRule;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.hops.rewrite.ProgramRewriteStatus;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContextFactory;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;

public class RewriteConstantFolding
extends HopRewriteRule {
    private static final String TMP_VARNAME = "__cf_tmp";
    private BasicProgramBlock _tmpPB = null;
    private ExecutionContext _tmpEC = null;

    @Override
    public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) {
        if (roots == null) {
            return null;
        }
        for (int i = 0; i < roots.size(); ++i) {
            Hop h = roots.get(i);
            roots.set(i, this.rule_ConstantFolding(h));
        }
        return roots;
    }

    @Override
    public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) {
        if (root == null) {
            return null;
        }
        return this.rule_ConstantFolding(root);
    }

    private Hop rule_ConstantFolding(Hop hop) {
        return this.rConstantFoldingExpression(hop);
    }

    private Hop rConstantFoldingExpression(Hop root) {
        if (root.isVisited()) {
            return root;
        }
        for (int i = 0; i < root.getInput().size(); ++i) {
            Hop h = root.getInput().get(i);
            this.rConstantFoldingExpression(h);
        }
        LiteralOp literal = null;
        if (root.getDataType() == Types.DataType.SCALAR && (RewriteConstantFolding.isApplicableUnaryOp(root) || RewriteConstantFolding.isApplicableBinaryOp(root) || RewriteConstantFolding.isApplicableTernaryOp(root) || RewriteConstantFolding.isApplicableNaryOp(root))) {
            literal = this.evalScalarOperation(root);
        } else if (RewriteConstantFolding.isApplicableFalseConjunctivePredicate(root)) {
            literal = new LiteralOp(false);
        } else if (RewriteConstantFolding.isApplicableTrueDisjunctivePredicate(root)) {
            literal = new LiteralOp(true);
        }
        if (literal != null) {
            if (!root.getParent().isEmpty()) {
                ArrayList<Hop> parents = new ArrayList<Hop>(root.getParent());
                for (Hop parent : parents) {
                    HopRewriteUtils.replaceChildReference(parent, root, literal);
                }
            } else {
                root = literal;
            }
        }
        root.setVisited();
        return root;
    }

    private LiteralOp evalScalarOperation(Hop bop) {
        DataOp tmpWrite = new DataOp(TMP_VARNAME, bop.getDataType(), bop.getValueType(), bop, Types.OpOpData.TRANSIENTWRITE, TMP_VARNAME);
        Dag<Lop> dag = new Dag<Lop>();
        Recompiler.rClearLops(tmpWrite);
        Lop lops = tmpWrite.constructLops();
        lops.addToDag(dag);
        ArrayList<Instruction> inst = dag.getJobs(null, ConfigurationManager.getDMLConfig());
        ExecutionContext ec = this.getExecutionContext();
        BasicProgramBlock pb = this.getProgramBlock();
        pb.setInstructions(inst);
        pb.execute(ec);
        ScalarObject so = (ScalarObject)ec.getVariable(TMP_VARNAME);
        LiteralOp literal = ScalarObjectFactory.createLiteralOp(so);
        tmpWrite.getInput().clear();
        bop.getParent().remove(tmpWrite);
        pb.setInstructions(null);
        ec.getVariables().removeAll();
        HopRewriteUtils.setOutputParametersForScalar(literal);
        return literal;
    }

    private BasicProgramBlock getProgramBlock() {
        if (this._tmpPB == null) {
            this._tmpPB = new BasicProgramBlock(new Program());
        }
        return this._tmpPB;
    }

    private ExecutionContext getExecutionContext() {
        if (this._tmpEC == null) {
            this._tmpEC = ExecutionContextFactory.createContext();
        }
        return this._tmpEC;
    }

    private static boolean isApplicableBinaryOp(Hop hop) {
        List<Hop> in = hop.getInput();
        return hop instanceof BinaryOp && in.get(0) instanceof LiteralOp && in.get(1) instanceof LiteralOp && ((BinaryOp)hop).getOp() != Types.OpOp2.CBIND && ((BinaryOp)hop).getOp() != Types.OpOp2.RBIND;
    }

    private static boolean isApplicableUnaryOp(Hop hop) {
        List<Hop> in = hop.getInput();
        return hop instanceof UnaryOp && in.get(0) instanceof LiteralOp && ((UnaryOp)hop).getOp() != Types.OpOp1.EXISTS && ((UnaryOp)hop).getOp() != Types.OpOp1.PRINT && ((UnaryOp)hop).getOp() != Types.OpOp1.ASSERT && ((UnaryOp)hop).getOp() != Types.OpOp1.STOP && hop.getDataType() == Types.DataType.SCALAR;
    }

    private static boolean isApplicableTernaryOp(Hop hop) {
        return HopRewriteUtils.isTernary(hop, Types.OpOp3.IFELSE, Types.OpOp3.MINUS_MULT, Types.OpOp3.PLUS_MULT) && hop.getInput().stream().allMatch(h -> h instanceof LiteralOp);
    }

    private static boolean isApplicableNaryOp(Hop hop) {
        return HopRewriteUtils.isNary(hop, Types.OpOpN.MIN, Types.OpOpN.MAX, Types.OpOpN.PLUS) && hop.getInput().stream().allMatch(h -> h instanceof LiteralOp);
    }

    private static boolean isApplicableFalseConjunctivePredicate(Hop hop) {
        List<Hop> in = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Types.OpOp2.AND) && hop.getDataType().isScalar() && (in.get(0) instanceof LiteralOp && !((LiteralOp)in.get(0)).getBooleanValue() || in.get(1) instanceof LiteralOp && !((LiteralOp)in.get(1)).getBooleanValue());
    }

    private static boolean isApplicableTrueDisjunctivePredicate(Hop hop) {
        List<Hop> in = hop.getInput();
        return HopRewriteUtils.isBinary(hop, Types.OpOp2.OR) && hop.getDataType().isScalar() && (in.get(0) instanceof LiteralOp && ((LiteralOp)in.get(0)).getBooleanValue() || in.get(1) instanceof LiteralOp && ((LiteralOp)in.get(1)).getBooleanValue());
    }
}

