/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import org.apache.sysds.runtime.DMLScriptException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
import org.apache.sysds.runtime.functionobjects.Multiply;
import org.apache.sysds.runtime.functionobjects.Plus;
import org.apache.sysds.runtime.functionobjects.ReduceAll;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;

/**
 * Shared base for compressed column groups.
 *
 * <p>Routes unary aggregate operations (sum, sum of squares, product, min, max) and the
 * transpose-self matrix multiply ({@code tsmm}) to a small set of abstract hooks that each concrete
 * encoding implements. It also provides static {@code tsmm} kernels that operate on a group's
 * dictionary: a row-major array with one column per entry of the group's column indices.
 */
public abstract class ColGroupCompressed extends AColGroup {
    private static final long serialVersionUID = 6219835795420081223L;

    /** Creates a group without assigning column indices. */
    protected ColGroupCompressed() {
        // nothing to initialise at this level
    }

    /**
     * Creates a group covering the given columns.
     *
     * @param colIndices indices of the matrix columns this group represents (forwarded to
     *                   {@link AColGroup})
     */
    protected ColGroupCompressed(int[] colIndices) {
        super(colIndices);
    }

    @Override
    public abstract double[] getValues();

    @Override
    public abstract boolean isLossy();

    /**
     * Aggregates this group's values with a MIN or MAX builtin, starting from {@code c}.
     *
     * @param c       initial value; callers pass +/-infinity or the running result held in {@code c[0]}
     * @param builtin the MIN or MAX function object to apply
     * @return the aggregated value
     */
    protected abstract double computeMxx(double c, Builtin builtin);

    /**
     * Column-wise MIN/MAX aggregation (used for {@link ReduceRow}).
     *
     * @param c       output array to aggregate into
     * @param builtin the MIN or MAX function object to apply
     */
    protected abstract void computeColMxx(double[] c, Builtin builtin);

    /**
     * Full sum aggregation (used for {@link ReduceAll}).
     *
     * @param c      output array to aggregate into
     * @param nRows  number of rows in the matrix
     * @param square {@code true} when the operator is {@link KahanPlusSq} (sum of squares)
     */
    protected abstract void computeSum(double[] c, int nRows, boolean square);

    /**
     * Row-wise sum aggregation (used for {@link ReduceCol}).
     *
     * @param c      output array to aggregate into
     * @param square {@code true} when the operator is {@link KahanPlusSq} (sum of squares)
     * @param rl     lower row bound (presumably inclusive; confirm in implementations)
     * @param ru     upper row bound (presumably exclusive; confirm in implementations)
     */
    protected abstract void computeRowSums(double[] c, boolean square, int rl, int ru);

    /**
     * Column-wise sum aggregation (used for {@link ReduceRow}).
     *
     * @param c      output array to aggregate into
     * @param nRows  number of rows in the matrix
     * @param square {@code true} when the operator is {@link KahanPlusSq} (sum of squares)
     */
    protected abstract void computeColSums(double[] c, int nRows, boolean square);

    /**
     * Row-wise MIN/MAX aggregation (used for {@link ReduceCol}).
     *
     * @param c       output array to aggregate into
     * @param builtin the MIN or MAX function object to apply
     * @param rl      lower row bound (presumably inclusive; confirm in implementations)
     * @param ru      upper row bound (presumably exclusive; confirm in implementations)
     */
    protected abstract void computeRowMxx(double[] c, Builtin builtin, int rl, int ru);

    /**
     * Full product aggregation (used for {@link ReduceAll}).
     *
     * @param c     output array to aggregate into
     * @param nRows number of rows in the matrix
     */
    protected abstract void computeProduct(double[] c, int nRows);

    /**
     * Row-wise product aggregation (used for {@link ReduceCol}).
     *
     * @param c  output array to aggregate into
     * @param rl lower row bound (presumably inclusive; confirm in implementations)
     * @param ru upper row bound (presumably exclusive; confirm in implementations)
     */
    protected abstract void computeRowProduct(double[] c, int rl, int ru);

    /**
     * Column-wise product aggregation (used for {@link ReduceRow}).
     *
     * @param c     output array to aggregate into
     * @param nRows number of rows in the matrix
     */
    protected abstract void computeColProduct(double[] c, int nRows);

    /** Returns the minimum value, via {@link #computeMxx} seeded with +infinity. */
    @Override
    public double getMin() {
        return computeMxx(Double.POSITIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MIN));
    }

    /** Returns the maximum value, via {@link #computeMxx} seeded with -infinity. */
    @Override
    public double getMax() {
        return computeMxx(Double.NEGATIVE_INFINITY, Builtin.getBuiltinFnObject(Builtin.BuiltinCode.MAX));
    }

    /** Plain (non-squared) column sums; delegates to the three-argument hook. */
    @Override
    public void computeColSums(double[] c, int nRows) {
        computeColSums(c, nRows, false);
    }

    /**
     * Dispatches a unary aggregate to the matching abstract hook.
     *
     * <p>The aggregation function selects the family: Plus/KahanPlus/KahanPlusSq give sums,
     * Multiply gives products, and the MIN/MAX builtins give extrema. The index function selects
     * the shape: ReduceAll gives the full result, ReduceCol gives per-row results and ReduceRow
     * gives per-column results. An index function outside these three is silently ignored.
     *
     * @param op    the aggregate operator
     * @param c     output array the hooks aggregate into
     * @param nRows number of rows in the matrix
     * @param rl    lower row bound for row-wise aggregates
     * @param ru    upper row bound for row-wise aggregates
     * @throws DMLScriptException if the function is not supported
     */
    @Override
    public final void unaryAggregateOperations(AggregateUnaryOperator op, double[] c, int nRows, int rl, int ru) {
        final ValueFunction fn = op.aggOp.increOp.fn;
        if (fn instanceof Plus || fn instanceof KahanPlus || fn instanceof KahanPlusSq) {
            final boolean square = fn instanceof KahanPlusSq;
            if (op.indexFn instanceof ReduceAll) {
                computeSum(c, nRows, square);
            } else if (op.indexFn instanceof ReduceCol) {
                computeRowSums(c, square, rl, ru);
            } else if (op.indexFn instanceof ReduceRow) {
                computeColSums(c, nRows, square);
            }
        } else if (fn instanceof Multiply) {
            if (op.indexFn instanceof ReduceAll) {
                computeProduct(c, nRows);
            } else if (op.indexFn instanceof ReduceCol) {
                computeRowProduct(c, rl, ru);
            } else if (op.indexFn instanceof ReduceRow) {
                computeColProduct(c, nRows);
            }
        } else if (fn instanceof Builtin) {
            final Builtin bop = (Builtin) fn;
            final Builtin.BuiltinCode bopC = bop.getBuiltinCode();
            if (bopC != Builtin.BuiltinCode.MAX && bopC != Builtin.BuiltinCode.MIN) {
                throw new DMLScriptException("unsupported builtin type: " + bop);
            }
            if (op.indexFn instanceof ReduceAll) {
                // the running extreme lives in c[0]
                c[0] = computeMxx(c[0], bop);
            } else if (op.indexFn instanceof ReduceCol) {
                computeRowMxx(c, bop, rl, ru);
            } else if (op.indexFn instanceof ReduceRow) {
                computeColMxx(c, bop);
            }
        } else {
            throw new DMLScriptException("Unknown UnaryAggregate operator on CompressedMatrixBlock");
        }
    }

    /**
     * Adds this group's {@code tsmm} contribution into {@code ret}.
     *
     * <p>Assumes {@code ret} is dense and allocated (it reads {@code getDenseBlockValues()});
     * confirm with callers.
     *
     * @param ret   output matrix, indexed row-major with {@code ret.getNumColumns()} columns
     * @param nRows number of rows in the matrix
     */
    @Override
    public final void tsmm(MatrixBlock ret, int nRows) {
        final double[] result = ret.getDenseBlockValues();
        tsmm(result, ret.getNumColumns(), nRows);
    }

    /**
     * Encoding-specific {@code tsmm} contribution.
     *
     * @param result     row-major output array
     * @param numColumns row stride of {@code result}
     * @param nRows      number of rows in the matrix
     */
    protected abstract void tsmm(double[] result, int numColumns, int nRows);

    /**
     * Adds {@code sum_k counts[k] * v_k^T v_k} into {@code result}, where {@code v_k} is row k
     * of the dictionary.
     *
     * <p>Uses the sparse or dense kernel depending on the dictionary's matrix-block form, and
     * skips empty dictionaries.
     *
     * @param result     row-major output array with stride {@code numColumns}
     * @param numColumns row stride of {@code result}
     * @param counts     per-dictionary-row weight (presumably the number of rows mapping to that
     *                   tuple; TODO confirm)
     * @param dict       the group's dictionary
     * @param colIndexes output column for each dictionary column
     */
    protected static void tsmm(double[] result, int numColumns, int[] counts, ADictionary dict, int[] colIndexes) {
        final ADictionary asMatrix = dict.getAsMatrixBlockDictionary(colIndexes.length);
        if (!(asMatrix instanceof MatrixBlockDictionary)) {
            // NOTE(review): `instanceof` is also false for null, in which case getValues() throws a
            // NullPointerException here; confirm whether getAsMatrixBlockDictionary can return null.
            tsmmDense(result, numColumns, asMatrix.getValues(), counts, colIndexes);
            return;
        }
        final MatrixBlock mb = ((MatrixBlockDictionary) asMatrix).getMatrixBlock();
        if (mb.isEmpty()) {
            return; // all-zero dictionary contributes nothing
        }
        if (mb.isInSparseFormat()) {
            tsmmSparse(result, numColumns, mb.getSparseBlock(), counts, colIndexes);
        } else {
            tsmmDense(result, numColumns, mb.getDenseBlockValues(), counts, colIndexes);
        }
    }

    /**
     * Dense-dictionary {@code tsmm} kernel.
     *
     * <p>{@code values} is read row-major with {@code colIndexes.length} columns. For each
     * dictionary row and each pair {@code i <= j}, it adds
     * {@code counts[row] * values[i] * values[j]} into
     * {@code result[colIndexes[i] * numColumns + colIndexes[j]]}. That is the upper triangle,
     * assuming {@code colIndexes} is sorted ascending.
     *
     * @param result     row-major output array with stride {@code numColumns}
     * @param numColumns row stride of {@code result}
     * @param values     dictionary values; {@code null} is a no-op
     * @param counts     per-dictionary-row weight
     * @param colIndexes output column for each dictionary column
     */
    protected static void tsmmDense(double[] result, int numColumns, double[] values, int[] counts, int[] colIndexes) {
        if (values == null) {
            return;
        }
        final int nCol = colIndexes.length;
        final int nDictRows = values.length / nCol;
        for (int row = 0; row < nDictRows; ++row) {
            final int rowOff = row * nCol;
            final double weight = counts[row];
            for (int i = 0; i < nCol; ++i) {
                final double scaled = values[rowOff + i] * weight;
                if (scaled == 0.0) {
                    continue; // zero contributes nothing to this output row
                }
                final int outRow = numColumns * colIndexes[i];
                for (int j = i; j < nCol; ++j) {
                    result[outRow + colIndexes[j]] += scaled * values[rowOff + j];
                }
            }
        }
    }

    /**
     * Sparse-dictionary {@code tsmm} kernel. It accumulates the same products as
     * {@link #tsmmDense}, iterating only the stored entries of each sparse row.
     *
     * @param result     row-major output array with stride {@code numColumns}
     * @param numColumns row stride of {@code result}
     * @param sb         sparse dictionary block
     * @param counts     per-dictionary-row weight
     * @param colIndexes output column for each dictionary column
     */
    protected static void tsmmSparse(double[] result, int numColumns, SparseBlock sb, int[] counts, int[] colIndexes) {
        final int nDictRows = sb.numRows();
        for (int row = 0; row < nDictRows; ++row) {
            if (sb.isEmpty(row)) {
                continue;
            }
            final int apos = sb.pos(row);
            final int aend = apos + sb.size(row);
            final int[] aix = sb.indexes(row);
            final double[] avals = sb.values(row);
            final double weight = counts[row]; // loop-invariant for this dictionary row
            for (int i = apos; i < aend; ++i) {
                final int outRow = colIndexes[aix[i]] * numColumns;
                final double scaled = avals[i] * weight;
                for (int j = i; j < aend; ++j) {
                    result[outRow + colIndexes[aix[j]]] += scaled * avals[j];
                }
            }
        }
    }
}

