/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.compress.colgroup;

import java.io.DataInput;
import java.io.IOException;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.colgroup.AColGroup;
import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
import org.apache.sysds.runtime.compress.colgroup.ADictBasedColGroup;
import org.apache.sysds.runtime.compress.colgroup.AOffsetsGroup;
import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
import org.apache.sysds.runtime.compress.colgroup.ColGroupUtils;
import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple;
import org.apache.sysds.runtime.compress.colgroup.IMapToDataGroup;
import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
import org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
import org.apache.sysds.runtime.compress.colgroup.offset.OffsetEmpty;
import org.apache.sysds.runtime.compress.colgroup.scheme.ConstScheme;
import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme;
import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
import org.apache.sysds.runtime.compress.estim.EstimationFactors;
import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy;
import org.apache.sysds.runtime.data.DenseBlock;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.CMOperator;
import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;

/**
 * Column group in which every row holds the same tuple of values.
 *
 * <p>The dictionary stores a single tuple (one value per column in {@code _colIndexes}). Row
 * indexes are therefore irrelevant for lookups, and row-wise operations broadcast that tuple (or a
 * pre-aggregate of it) to every requested row.</p>
 */
public class ColGroupConst extends ADictBasedColGroup
    implements IContainDefaultTuple, AOffsetsGroup, IMapToDataGroup {
    private static final long serialVersionUID = -7387793538322386611L;

    private ColGroupConst(IColIndex colIndices, IDictionary dict) {
        super(colIndices, dict);
    }

    /**
     * Create a constant column group from a dictionary.
     *
     * <p>A {@code null} dictionary yields a {@link ColGroupEmpty}. If the dictionary holds more than
     * one tuple, only the first {@code colIndices.size()} values (its first tuple) are kept.</p>
     *
     * @param colIndices The column indices covered by the group
     * @param dict       The dictionary holding the constant tuple, may be null
     * @return A constant or empty column group
     */
    public static AColGroup create(IColIndex colIndices, IDictionary dict) {
        if (dict == null) {
            return new ColGroupEmpty(colIndices);
        }
        final int nCol = colIndices.size();
        if (dict.getNumberOfValues(nCol) > 1) {
            // Collapse a multi-tuple dictionary to its first tuple.
            final double[] firstTuple = new double[nCol];
            for (int i = 0; i < nCol; ++i) {
                firstTuple[i] = dict.getValue(i);
            }
            return create(colIndices, firstTuple);
        }
        return new ColGroupConst(colIndices, dict);
    }

    /**
     * Create a constant column group covering columns {@code 0 .. values.length - 1}.
     *
     * @param values The constant tuple
     * @return A constant or empty column group
     */
    public static AColGroup create(double[] values) {
        return create(ColIndexFactory.create(values.length), values);
    }

    /**
     * Create a constant column group where every column holds the same scalar value.
     *
     * @param cols  The column indices, must not be empty
     * @param value The value of every cell
     * @return A {@link ColGroupEmpty} if {@code value} is zero, otherwise a constant group
     * @throws DMLCompressionException If {@code cols} is empty
     */
    public static AColGroup create(IColIndex cols, double value) {
        if (cols.size() == 0) {
            throw new DMLCompressionException("Invalid number of columns");
        }
        if (value == 0.0) {
            return new ColGroupEmpty(cols);
        }
        final int numCols = cols.size();
        final double[] values = new double[numCols];
        for (int i = 0; i < numCols; ++i) {
            values[i] = value;
        }
        return create(cols, values);
    }

    /**
     * Create a constant column group from an explicit tuple.
     *
     * @param cols   The column indices
     * @param values The constant tuple, one value per column
     * @return A {@link ColGroupEmpty} if all values are zero, otherwise a constant group
     * @throws DMLCompressionException If the number of values does not match the number of columns
     */
    public static AColGroup create(IColIndex cols, double[] values) {
        if (cols.size() != values.length) {
            throw new DMLCompressionException("Invalid size of values compared to columns");
        }
        for (double d : values) {
            if (d != 0.0) {
                return create(cols, (IDictionary) Dictionary.create(values));
            }
        }
        return new ColGroupEmpty(cols);
    }

    /**
     * Create a constant column group covering columns {@code 0 .. numCols - 1} from a dictionary.
     *
     * <p>For a {@link MatrixBlockDictionary} the matrix width must equal {@code numCols}. Additional
     * rows are collapsed to the first one by {@link #create(IColIndex, IDictionary)}. For other
     * dictionaries the flat value array must contain exactly {@code numCols} values.</p>
     *
     * @param numCols The number of columns
     * @param dict    The dictionary, may be null (yields an empty group)
     * @return A constant or empty column group
     * @throws DMLCompressionException If the dictionary width does not match {@code numCols}
     */
    public static AColGroup create(int numCols, IDictionary dict) {
        if (dict == null) {
            return create(ColIndexFactory.create(numCols), dict);
        }
        final boolean mismatch;
        if (dict instanceof MatrixBlockDictionary) {
            // The previous check used '&&' with a single-row test, so a 1-row dictionary of the
            // wrong width slipped through and produced a malformed group. Only the width matters here.
            mismatch = ((MatrixBlockDictionary) dict).getMatrixBlock().getNumColumns() != numCols;
        } else {
            mismatch = numCols != dict.getValues().length;
        }
        if (mismatch) {
            throw new DMLCompressionException("Invalid construction of const column group with different number of columns in arguments");
        }
        return create(ColIndexFactory.create(numCols), dict);
    }

    /**
     * Create a constant column group covering columns {@code 0 .. numCols - 1} with one scalar value.
     *
     * @param numCols The number of columns, must be positive
     * @param value   The value of every cell
     * @return A {@link ColGroupEmpty} if {@code value} is zero, otherwise a constant group
     * @throws DMLCompressionException If {@code numCols <= 0}
     */
    public static AColGroup create(int numCols, double value) {
        if (numCols <= 0) {
            throw new DMLCompressionException("Invalid construction of constant column group with cols: " + numCols);
        }
        final IColIndex colIndices = ColIndexFactory.create(numCols);
        if (value == 0.0) {
            return new ColGroupEmpty(colIndices);
        }
        return create(colIndices, value);
    }

    /**
     * Get the constant tuple as a dense array.
     *
     * <p>For a {@link MatrixBlockDictionary} a warning is logged. A sparse or unallocated matrix
     * block is expanded into a new dense array of width {@code getNumColumns()}. A dense one returns
     * the backing array directly, so callers must not modify it.</p>
     *
     * @return The constant values
     */
    public double[] getValues() {
        if (!(this._dict instanceof MatrixBlockDictionary)) {
            return this._dict.getValues();
        }
        LOG.warn("Inefficient get values for constant column group (but it is allowed)");
        final MatrixBlock mb = ((MatrixBlockDictionary) this._dict).getMatrixBlock();
        if (!mb.isInSparseFormat()) {
            final double[] dense = mb.getDenseBlockValues();
            // An unallocated dense block means all zeros.
            return dense != null ? dense : new double[mb.getNumColumns()];
        }
        final double[] values = new double[mb.getNumColumns()];
        final SparseBlock sb = mb.getSparseBlock();
        if (sb != null && !sb.isEmpty(0)) {
            final int apos = sb.pos(0);
            final int alen = apos + sb.size(0);
            final double[] aval = sb.values(0);
            final int[] aix = sb.indexes(0);
            for (int j = apos; j < alen; ++j) {
                values[aix[j]] = aval[j];
            }
        }
        return values;
    }

    @Override
    protected void computeRowMxx(double[] c, Builtin builtin, int rl, int ru, double[] preAgg) {
        // Every row has the same pre-aggregated extreme value.
        final double v = preAgg[0];
        for (int i = rl; i < ru; ++i) {
            c[i] = builtin.execute(c[i], v);
        }
    }

    @Override
    public AColGroup.CompressionType getCompType() {
        return AColGroup.CompressionType.CONST;
    }

    @Override
    public AColGroup.ColGroupType getColGroupType() {
        return AColGroup.ColGroupType.CONST;
    }

    /**
     * Add the sparse constant tuple (row 0 of {@code sb}) into rows {@code [rl, ru)} of the dense
     * target, shifted by {@code offR} rows and {@code offC} columns.
     */
    @Override
    protected void decompressToDenseBlockSparseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, SparseBlock sb) {
        if (sb.isEmpty(0)) {
            return;
        }
        final int apos = sb.pos(0);
        // End bound is apos + size; the previous code used size alone as the end index.
        final int alen = apos + sb.size(0);
        final int[] aix = sb.indexes(0);
        final double[] avals = sb.values(0);
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            for (int j = apos; j < alen; ++j) {
                c[off + this._colIndexes.get(aix[j])] += avals[j];
            }
        }
    }

    /**
     * Add the dense constant tuple {@code values} into rows {@code [rl, ru)} of the dense target.
     * A fast path handles a contiguous block that this group fully covers.
     */
    @Override
    protected void decompressToDenseBlockDenseDictionary(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        if (db.isContiguous() && this._colIndexes.size() == db.getDim(1) && offC == 0) {
            this.decompressToDenseBlockAllColumnsContiguous(db, rl + offR, ru + offR, values);
        } else {
            this.decompressToDenseBlockGeneric(db, rl, ru, offR, offC, values);
        }
    }

    /**
     * Append the sparse constant tuple (row 0 of {@code sb}) to rows {@code [rl, ru)} of the sparse
     * target, shifted by {@code offR} rows and {@code offC} columns.
     */
    @Override
    protected void decompressToSparseBlockSparseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, SparseBlock sb) {
        if (sb.isEmpty(0)) {
            return;
        }
        final int apos = sb.pos(0);
        final int alen = apos + sb.size(0);
        final int[] aix = sb.indexes(0);
        final double[] avals = sb.values(0);
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            for (int j = apos; j < alen; ++j) {
                ret.append(offT, this._colIndexes.get(aix[j]) + offC, avals[j]);
            }
        }
    }

    /**
     * Append the dense constant tuple {@code values} to rows {@code [rl, ru)} of the sparse target.
     */
    @Override
    protected void decompressToSparseBlockDenseDictionary(SparseBlock ret, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = this._colIndexes.size();
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            for (int j = 0; j < nCol; ++j) {
                ret.append(offT, this._colIndexes.get(j) + offC, values[j]);
            }
        }
    }

    /**
     * Fast path for a contiguous dense block whose width equals this group's column count. Row
     * {@code r} starts at offset {@code r * nCol} of {@code db.values(0)}.
     *
     * @param db     The contiguous target block
     * @param rl     First target row (already shifted by the row offset)
     * @param ru     End target row, exclusive (already shifted by the row offset)
     * @param values The constant tuple
     */
    private void decompressToDenseBlockAllColumnsContiguous(DenseBlock db, int rl, int ru, double[] values) {
        final double[] c = db.values(0);
        final int nCol = this._colIndexes.size();
        for (int r = rl; r < ru; ++r) {
            final int off = r * nCol;
            for (int j = 0; j < nCol; ++j) {
                c[off + j] += values[j];
            }
        }
    }

    /**
     * General dense decompression: add {@code values[j]} at column {@code _colIndexes.get(j) + offC}
     * of every target row {@code [rl + offR, ru + offR)}.
     */
    private void decompressToDenseBlockGeneric(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) {
        final int nCol = this._colIndexes.size();
        for (int i = rl, offT = rl + offR; i < ru; ++i, ++offT) {
            final double[] c = db.values(offT);
            final int off = db.pos(offT) + offC;
            for (int j = 0; j < nCol; ++j) {
                c[off + this._colIndexes.get(j)] += values[j];
            }
        }
    }

    /**
     * Get a cell value. The row is ignored because all rows are identical; {@code colIdx} is used
     * directly as an offset into the constant tuple.
     */
    @Override
    public double getIdx(int r, int colIdx) {
        return this._dict.getValue(colIdx);
    }

    @Override
    public AColGroup scalarOperation(ScalarOperator op) {
        return create(this._colIndexes, this._dict.applyScalarOp(op));
    }

    @Override
    public AColGroup unaryOperation(UnaryOperator op) {
        return create(this._colIndexes, this._dict.applyUnaryOp(op));
    }

    @Override
    public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe) {
        return create(this._colIndexes, this._dict.binOpLeft(op, v, this._colIndexes));
    }

    @Override
    public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, boolean isRowSafe) {
        return create(this._colIndexes, this._dict.binOpRight(op, v, this._colIndexes));
    }

    /**
     * Add this group's constant tuple into a full-width row vector {@code constV}, at the positions
     * given by {@code _colIndexes}.
     *
     * @param constV The row vector to add into (indexed by absolute column)
     */
    public final void addToCommon(double[] constV) {
        final MatrixBlock mb;
        if (this._dict instanceof IdentityDictionary) {
            mb = ((IdentityDictionary) this._dict).getMBDict().getMatrixBlock();
        } else if (this._dict instanceof MatrixBlockDictionary) {
            mb = ((MatrixBlockDictionary) this._dict).getMatrixBlock();
        } else {
            this.addToCommonDense(constV, this._dict.getValues());
            return;
        }
        if (mb.isInSparseFormat()) {
            this.addToCommonSparse(constV, mb.getSparseBlock());
        } else {
            this.addToCommonDense(constV, mb.getDenseBlockValues());
        }
    }

    /** Add a dense tuple into {@code constV} at this group's column positions. */
    private final void addToCommonDense(double[] constV, double[] values) {
        final int nCol = this._colIndexes.size();
        for (int i = 0; i < nCol; ++i) {
            constV[this._colIndexes.get(i)] += values[i];
        }
    }

    /** Add row 0 of a sparse block into {@code constV} at this group's column positions. */
    private final void addToCommonSparse(double[] constV, SparseBlock sb) {
        if (sb == null || sb.isEmpty(0)) {
            return;
        }
        final int apos = sb.pos(0);
        final int alen = apos + sb.size(0);
        final int[] aix = sb.indexes(0);
        final double[] aval = sb.values(0);
        for (int i = apos; i < alen; ++i) {
            constV[this._colIndexes.get(aix[i])] += aval[i];
        }
    }

    @Override
    protected double computeMxx(double c, Builtin builtin) {
        return this._dict.aggregate(c, builtin);
    }

    @Override
    protected void computeColMxx(double[] c, Builtin builtin) {
        this._dict.aggregateCols(c, builtin, this._colIndexes);
    }

    @Override
    protected void computeSum(double[] c, int nRows) {
        // The single tuple occurs nRows times.
        c[0] += this._dict.sum(new int[] {nRows}, this._colIndexes.size());
    }

    @Override
    public void computeColSums(double[] c, int nRows) {
        this._dict.colSum(c, new int[] {nRows}, this._colIndexes);
    }

    @Override
    protected void computeSumSq(double[] c, int nRows) {
        c[0] += this._dict.sumSq(new int[] {nRows}, this._colIndexes.size());
    }

    @Override
    protected void computeColSumsSq(double[] c, int nRows) {
        this._dict.colSumSq(c, new int[] {nRows}, this._colIndexes);
    }

    @Override
    protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
        // Every row has the same pre-aggregated row sum.
        final double v = preAgg[0];
        for (int rix = rl; rix < ru; ++rix) {
            c[rix] += v;
        }
    }

    @Override
    public int getNumValues() {
        return 1;
    }

    @Override
    public void tsmm(double[] result, int numColumns, int nRows) {
        tsmm(result, numColumns, new int[] {nRows}, this._dict, this._colIndexes);
    }

    /**
     * Left matrix multiplication without pre-aggregation: {@code matrix %*% this} equals the outer
     * product of the row sums of {@code matrix} (restricted to columns {@code [cl, cu)}) with the
     * constant tuple.
     */
    @Override
    public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock result, int rl, int ru, int cl, int cu) {
        LOG.warn("Do not use leftMultByMatrixNoPreAgg on ColGroupConst");
        // A partial row sum is needed whenever the column range is not the full width. The previous
        // '&&' fell back to the full row sum for ranges such as [0, cu) with cu < ncol.
        final boolean partialColumns = cl != 0 || cu != matrix.getNumColumns();
        final double[] rowSum = partialColumns
            ? CLALibLeftMultBy.rowSum(matrix, rl, ru, cl, cu)
            : matrix.rowSum().getDenseBlockValues();
        this.leftMultByRowSum(rowSum, result, rl, ru);
    }

    /**
     * Left multiplication by another column group: uses the column sums of {@code lhs} as the
     * row-sum vector of the outer product.
     */
    @Override
    public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result, int nRows) {
        LOG.warn("Should never call leftMultByMatrixByAColGroup on ColGroupConst");
        final double[] rowSum = new double[result.getNumRows()];
        lhs.computeColSums(rowSum, nRows);
        this.leftMultByRowSum(rowSum, result, 0, result.getNumRows());
    }

    /** Add the outer product of {@code rowSum[rl..ru)} and the constant tuple into {@code result}. */
    private void leftMultByRowSum(double[] rowSum, MatrixBlock result, int rl, int ru) {
        final double[] resV = result.getDenseBlockValues();
        final int nColRet = result.getNumColumns();
        if (this._dict instanceof MatrixBlockDictionary) {
            final MatrixBlock mb = ((MatrixBlockDictionary) this._dict).getMatrixBlock();
            if (mb.isInSparseFormat()) {
                ColGroupUtils.outerProduct(rowSum, mb.getSparseBlock(), this._colIndexes, resV, nColRet, rl, ru);
                return;
            }
        }
        ColGroupUtils.outerProduct(rowSum, this._dict.getValues(), this._colIndexes, resV, nColRet, rl, ru);
    }

    @Override
    public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
        throw new DMLCompressionException("Should not be called");
    }

    @Override
    protected AColGroup sliceSingleColumn(int idx) {
        // create(IColIndex, double[]) already returns an empty group for a zero value.
        return create(ColIndexFactory.create(1), new double[] {this._dict.getValue(idx)});
    }

    @Override
    protected AColGroup sliceMultiColumns(int idStart, int idEnd, IColIndex outputCols) {
        final IDictionary retD = this._dict.sliceOutColumnRange(idStart, idEnd, this._colIndexes.size());
        return create(outputCols, retD);
    }

    @Override
    public boolean containsValue(double pattern) {
        return this._dict.containsValue(pattern);
    }

    @Override
    public long getNumberNonZeros(int nRows) {
        return this._dict.getNumberNonZeros(new int[] {nRows}, this._colIndexes.size());
    }

    @Override
    public AColGroup replace(double pattern, double replace) {
        final IDictionary replaced = this._dict.replace(pattern, replace, this._colIndexes.size());
        return create(this._colIndexes, replaced);
    }

    @Override
    protected void computeProduct(double[] c, int nRows) {
        this._dict.product(c, new int[] {nRows}, this._colIndexes.size());
    }

    @Override
    protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
        // Every row has the same pre-aggregated row product.
        final double v = preAgg[0];
        for (int rix = rl; rix < ru; ++rix) {
            c[rix] *= v;
        }
    }

    @Override
    protected void computeColProduct(double[] c, int nRows) {
        this._dict.colProduct(c, new int[] {nRows}, this._colIndexes);
    }

    @Override
    protected double[] preAggSumRows() {
        return this._dict.sumAllRowsToDouble(this._colIndexes.size());
    }

    @Override
    protected double[] preAggSumSqRows() {
        return this._dict.sumAllRowsToDoubleSq(this._colIndexes.size());
    }

    @Override
    protected double[] preAggProductRows() {
        return this._dict.productAllRowsToDouble(this._colIndexes.size());
    }

    @Override
    protected double[] preAggBuiltinRows(Builtin builtin) {
        return this._dict.aggregateRows(builtin, this._colIndexes.size());
    }

    /**
     * Central moment over the group: the first dictionary value weighted by {@code nRows}. This
     * presumably assumes a single-column group (only value 0 is read) — TODO confirm with callers.
     */
    @Override
    public CM_COV_Object centralMoment(CMOperator op, int nRows) {
        final CM_COV_Object ret = new CM_COV_Object();
        op.fn.execute(ret, this._dict.getValue(0), nRows);
        return ret;
    }

    @Override
    public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int nRows) {
        final IDictionary d = this._dict.rexpandCols(max, ignore, cast, this._colIndexes.size());
        if (d == null) {
            return ColGroupEmpty.create(max);
        }
        return create(max, d);
    }

    @Override
    public double getCost(ComputationCostEstimator e, int nRows) {
        final int nCols = this.getNumCols();
        return e.getCost(nRows, 1, nCols, 1, 1.0);
    }

    protected AColGroup copyAndSet(IColIndex colIndexes, double[] newDictionary) {
        return create(colIndexes, (IDictionary) Dictionary.create(newDictionary));
    }

    @Override
    protected AColGroup copyAndSet(IColIndex colIndexes, IDictionary newDictionary) {
        return create(colIndexes, newDictionary);
    }

    @Override
    protected AColGroup allocateRightMultiplication(MatrixBlock right, IColIndex colIndexes, IDictionary preAgg) {
        if (colIndexes != null && preAgg != null) {
            return create(colIndexes, preAgg);
        }
        return null;
    }

    /**
     * Deserialize a constant group: column indexes followed by the dictionary. The dictionary is
     * used as read (no collapsing or empty-group conversion).
     *
     * @param in The input to read from
     * @return The deserialized group
     * @throws IOException If reading fails
     */
    public static ColGroupConst read(DataInput in) throws IOException {
        final IColIndex cols = ColIndexFactory.read(in);
        final IDictionary dict = DictionaryFactory.read(in);
        return new ColGroupConst(cols, dict);
    }

    /** Any row slice of a constant group is the same group. */
    @Override
    public AColGroup sliceRows(int rl, int ru) {
        return this;
    }

    /**
     * Append another group below this one. This is only possible (returning {@code this}) if the
     * other group is constant with the same number of columns and an equal dictionary; otherwise
     * {@code null} is returned.
     */
    @Override
    public AColGroup append(AColGroup g) {
        if (g instanceof ColGroupConst && g._colIndexes.size() == this._colIndexes.size() && ((ColGroupConst) g)._dict.equals(this._dict)) {
            return this;
        }
        return null;
    }

    /**
     * Append many groups. Equal constant groups are skipped. The first group that is neither
     * constant nor empty takes over via its own {@code appendNInternal}.
     *
     * @throws DMLCompressionException If any group covers different columns
     * @throws NotImplementedException  If a non-equal constant group or an empty group is encountered
     */
    @Override
    public AColGroup appendNInternal(AColGroup[] g, int blen, int rlen) {
        for (AColGroup gs : g) {
            if (!this._colIndexes.equals(gs._colIndexes)) {
                throw new DMLCompressionException("Invalid columns not matching " + gs._colIndexes + " " + this._colIndexes);
            }
            if (gs instanceof ColGroupConst) {
                if (this._dict.equals(((ColGroupConst) gs)._dict)) {
                    continue;
                }
                throw new NotImplementedException("Appending const not equivalent");
            }
            if (gs instanceof ColGroupEmpty) {
                throw new NotImplementedException("Appending empty and const");
            }
            return gs.appendNInternal(g, blen, rlen);
        }
        return this;
    }

    @Override
    public ICLAScheme getCompressionScheme() {
        return ConstScheme.create(this);
    }

    @Override
    public AColGroup recompress() {
        return this;
    }

    @Override
    public CompressedSizeInfoColGroup getCompressionInfo(int nRow) {
        final EstimationFactors ef = new EstimationFactors(1, 1, 1, this._dict.getSparsity());
        return new CompressedSizeInfoColGroup(this._colIndexes, ef, this.estimateInMemorySize(), AColGroup.CompressionType.CONST, this.getEncoding());
    }

    @Override
    public IEncode getEncoding() {
        return EncodingFactory.create(this);
    }

    @Override
    public boolean sameIndexStructure(AColGroupCompressed that) {
        return that instanceof ColGroupEmpty || that instanceof ColGroupConst;
    }

    /**
     * The default tuple is the constant tuple itself. This delegates to {@link #getValues()} so that
     * sparse matrix-block dictionaries are also expanded correctly.
     */
    @Override
    public double[] getDefaultTuple() {
        return this.getValues();
    }

    @Override
    protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) {
        return create(newColIndex, this._dict.reorder(reordering));
    }

    @Override
    public String toString() {
        final StringBuilder sb = new StringBuilder();
        sb.append(super.toString());
        sb.append(String.format("\n%15s", "Values: "));
        sb.append(this._dict.getClass().getSimpleName());
        sb.append(this._dict.getString(this._colIndexes.size()));
        return sb.toString();
    }

    /** A constant group has no explicit row offsets. */
    @Override
    public AOffset getOffsets() {
        return new OffsetEmpty();
    }

    /** A constant group has no row-to-tuple mapping; an empty map is returned. */
    @Override
    public AMapToData getMapToData() {
        return MapToFactory.create(0, 0);
    }
}

