/*
 * Decompiled with CFR 0.152.
 */
package org.apache.asterix.optimizer.rules;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.asterix.dataflow.data.common.TypeResolverUtil;
import org.apache.asterix.lang.common.util.FunctionUtil;
import org.apache.asterix.om.functions.BuiltinFunctions;
import org.apache.asterix.om.typecomputer.base.TypeCastUtils;
import org.apache.asterix.om.typecomputer.impl.TypeComputeUtils;
import org.apache.asterix.om.types.IAType;
import org.apache.commons.lang3.mutable.Mutable;
import org.apache.commons.lang3.mutable.MutableObject;
import org.apache.hyracks.algebricks.common.exceptions.AlgebricksException;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalExpression;
import org.apache.hyracks.algebricks.core.algebra.base.ILogicalOperator;
import org.apache.hyracks.algebricks.core.algebra.base.IOptimizationContext;
import org.apache.hyracks.algebricks.core.algebra.base.LogicalExpressionTag;
import org.apache.hyracks.algebricks.core.algebra.expressions.AbstractFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.expressions.IVariableTypeEnvironment;
import org.apache.hyracks.algebricks.core.algebra.expressions.ScalarFunctionCallExpression;
import org.apache.hyracks.algebricks.core.algebra.functions.FunctionIdentifier;
import org.apache.hyracks.algebricks.core.algebra.functions.IFunctionInfo;
import org.apache.hyracks.algebricks.core.algebra.typing.ITypingContext;
import org.apache.hyracks.algebricks.core.rewriter.base.IAlgebraicRewriteRule;

public class InjectTypeCastForFunctionArgumentsRule
implements IAlgebraicRewriteRule {
    private static final Map<FunctionIdentifier, BiIntPredicate> FUN_TO_ARG_CHECKER = new HashMap<FunctionIdentifier, BiIntPredicate>();

    public static void addFunctionAndArgChecker(FunctionIdentifier function, BiIntPredicate argChecker) {
        FUN_TO_ARG_CHECKER.put(function, argChecker);
    }

    public boolean rewritePost(Mutable<ILogicalOperator> opRef, IOptimizationContext context) throws AlgebricksException {
        ILogicalOperator op = (ILogicalOperator)opRef.getValue();
        if (op.getInputs().isEmpty()) {
            return false;
        }
        context.computeAndSetTypeEnvironmentForOperator(op);
        if (op.acceptExpressionTransform(exprRef -> InjectTypeCastForFunctionArgumentsRule.injectTypeCast(op, (Mutable<ILogicalExpression>)exprRef, context))) {
            context.computeAndSetTypeEnvironmentForOperator(op);
            return true;
        }
        return false;
    }

    private static boolean injectTypeCast(ILogicalOperator op, Mutable<ILogicalExpression> exprRef, IOptimizationContext context) throws AlgebricksException {
        ILogicalExpression expr = (ILogicalExpression)exprRef.getValue();
        if (expr.getExpressionTag() != LogicalExpressionTag.FUNCTION_CALL) {
            return false;
        }
        boolean rewritten = false;
        AbstractFunctionCallExpression func = (AbstractFunctionCallExpression)expr;
        for (Mutable argRef : func.getArguments()) {
            if (!InjectTypeCastForFunctionArgumentsRule.injectTypeCast(op, (Mutable<ILogicalExpression>)argRef, context)) continue;
            context.computeAndSetTypeEnvironmentForOperator(op);
            rewritten = true;
        }
        FunctionIdentifier funcId = func.getFunctionIdentifier();
        if (FUN_TO_ARG_CHECKER.containsKey(funcId)) {
            rewritten |= InjectTypeCastForFunctionArgumentsRule.rewriteFunction(op, func, FUN_TO_ARG_CHECKER.get(funcId), context);
        }
        return rewritten;
    }

    private static boolean rewriteFunction(ILogicalOperator op, AbstractFunctionCallExpression func, BiIntPredicate argChecker, IOptimizationContext context) throws AlgebricksException {
        IVariableTypeEnvironment env = op.computeInputTypeEnvironment((ITypingContext)context);
        IAType producedType = (IAType)env.getType((ILogicalExpression)func);
        List argRefs = func.getArguments();
        int argSize = argRefs.size();
        boolean rewritten = false;
        for (int argIndex = 0; argIndex < argSize; ++argIndex) {
            if (argChecker != null && !argChecker.test(argIndex, argSize)) continue;
            rewritten |= InjectTypeCastForFunctionArgumentsRule.rewriteFunctionArgument((Mutable<ILogicalExpression>)((Mutable)argRefs.get(argIndex)), producedType, env);
        }
        return rewritten;
    }

    private static boolean rewriteFunctionArgument(Mutable<ILogicalExpression> argRef, IAType funcOutputType, IVariableTypeEnvironment env) throws AlgebricksException {
        ILogicalExpression argExpr = (ILogicalExpression)argRef.getValue();
        IAType type = (IAType)env.getType(argExpr);
        if (TypeResolverUtil.needsCast((IAType)funcOutputType, (IAType)type)) {
            ScalarFunctionCallExpression castFunc = new ScalarFunctionCallExpression((IFunctionInfo)FunctionUtil.getFunctionInfo((FunctionIdentifier)BuiltinFunctions.CAST_TYPE), new ArrayList<MutableObject>(Collections.singletonList(new MutableObject((Object)argExpr))));
            castFunc.setSourceLocation(argExpr.getSourceLocation());
            IAType funcOutputPrimeType = TypeComputeUtils.getActualType((IAType)funcOutputType);
            TypeCastUtils.setRequiredAndInputTypes((AbstractFunctionCallExpression)castFunc, (IAType)funcOutputPrimeType, (IAType)type, (boolean)false);
            argRef.setValue((Object)castFunc);
            return true;
        }
        return false;
    }

    public static boolean switchResultArg(int argIdx, int numArguments) {
        return argIdx > 1 && (argIdx % 2 == 0 || argIdx == numArguments - 1);
    }

    static {
        InjectTypeCastForFunctionArgumentsRule.addFunctionAndArgChecker(BuiltinFunctions.SWITCH_CASE, InjectTypeCastForFunctionArgumentsRule::switchResultArg);
        InjectTypeCastForFunctionArgumentsRule.addFunctionAndArgChecker(BuiltinFunctions.IF_MISSING, null);
        InjectTypeCastForFunctionArgumentsRule.addFunctionAndArgChecker(BuiltinFunctions.IF_NULL, null);
        InjectTypeCastForFunctionArgumentsRule.addFunctionAndArgChecker(BuiltinFunctions.IF_MISSING_OR_NULL, null);
        InjectTypeCastForFunctionArgumentsRule.addFunctionAndArgChecker(BuiltinFunctions.IF_SYSTEM_NULL, null);
    }

    @FunctionalInterface
    public static interface BiIntPredicate {
        public boolean test(int var1, int var2);
    }
}

