/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.hops;

import java.util.ArrayList;
import java.util.EnumSet;
import java.util.Locale;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.HopsException;
import org.apache.sysds.hops.LiteralOp;
import org.apache.sysds.hops.MemoTable;
import org.apache.sysds.hops.MultiThreadedHop;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.CumulativeOffsetBinary;
import org.apache.sysds.lops.CumulativePartialAggregate;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.lops.Unary;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;

/**
 * High-level operator (hop) for unary operations on a single input: element-wise math
 * functions, cumulative aggregates, casts, IQM/median, and metadata queries such as
 * nrow/ncol/length.
 */
public class UnaryOp
extends MultiThreadedHop {
    /** Whether offsets of Spark cumulative aggregates may be broadcast (still subject to the broadcast budget). */
    private static final boolean ALLOW_CUMAGG_BROADCAST = true;
    /** Caching toggle for cumulative aggregates; not referenced within this class. */
    private static final boolean ALLOW_CUMAGG_CACHING = false;

    /**
     * Operations treated as sparsity-preserving: zero inputs map to zero outputs (or only the
     * physical representation changes), so the output keeps the input's number of non-zeros.
     * COS, COSH and ACOS are deliberately absent: cos(0) = cosh(0) = 1 and acos(0) = pi/2.
     * Shared by size propagation and memo-based inference so both agree.
     */
    private static final EnumSet<Types.OpOp1> SPARSITY_PRESERVING_OPS = EnumSet.of(
        Types.OpOp1.ABS, Types.OpOp1.SIN, Types.OpOp1.TAN, Types.OpOp1.SINH, Types.OpOp1.TANH,
        Types.OpOp1.ASIN, Types.OpOp1.ATAN, Types.OpOp1.SQRT, Types.OpOp1.ROUND, Types.OpOp1.SPROP,
        Types.OpOp1.COMPRESS, Types.OpOp1.DECOMPRESS);

    private Types.OpOp1 _op = null;

    /** No-arg constructor used by {@link #clone()}. */
    private UnaryOp() {
    }

    /**
     * Creates a unary hop and wires it into the DAG as the single parent-link of {@code inp}.
     *
     * @param l   hop name
     * @param dt  output data type
     * @param vt  output value type
     * @param o   unary operation
     * @param inp input hop
     */
    public UnaryOp(String l, Types.DataType dt, Types.ValueType vt, Types.OpOp1 o, Hop inp) {
        super(l, dt, vt);
        this.getInput().add(inp);
        inp.getParent().add(this);
        this._op = o;
        this.refreshSizeInformation();
    }

    @Override
    public void checkArity() {
        HopsException.check(this._input.size() == 1, this, "should have arity 1 but has arity %d", this._input.size());
    }

    /** @return the unary operation of this hop */
    public Types.OpOp1 getOp() {
        return this._op;
    }

    /** @return an explain string of the form {@code u(<op>)}, lower-cased locale-independently */
    @Override
    public String getOpString() {
        // Locale.ROOT: avoid locale-specific casing (e.g. Turkish dotless i for "SIN")
        return "u(" + this._op.toString().toLowerCase(Locale.ROOT) + ")";
    }

    /**
     * Whether this hop produces a scalar, or is a cast of a scalar input to a matrix/frame.
     * Such hops are compiled to CP-only lops and never run on the GPU.
     */
    private boolean isScalarOrScalarCast() {
        // keep short-circuit order: the input is only inspected for the cast cases
        return this.getDataType() == Types.DataType.SCALAR
            || this._op == Types.OpOp1.CAST_AS_MATRIX && this.getInput().get(0).getDataType() == Types.DataType.SCALAR
            || this._op == Types.OpOp1.CAST_AS_FRAME && this.getInput().get(0).getDataType() == Types.DataType.SCALAR;
    }

    @Override
    public boolean isGPUEnabled() {
        if (!DMLScript.USE_ACCELERATOR || this.isScalarOrScalarCast()) {
            return false;
        }
        switch (this._op) {
            case EXP:
            case SQRT:
            case LOG:
            case ABS:
            case ROUND:
            case FLOOR:
            case CEIL:
            case SIN:
            case COS:
            case TAN:
            case ASIN:
            case ACOS:
            case ATAN:
            case SINH:
            case COSH:
            case TANH:
            case SIGN:
            case SIGMOID:
            case CUMSUM:
            case CUMPROD:
            case CUMMIN:
            case CUMMAX:
            case CUMSUMPROD:
                return true;
            default:
                return false;
        }
    }

    @Override
    public boolean isMultiThreadedOpType() {
        return this.isCumulativeUnaryOperation() || this.isExpensiveUnaryOperation();
    }

    /**
     * Constructs (once) and caches the lop for this hop.
     * <ul>
     *   <li>Scalar outputs and scalar-to-matrix/frame casts become a {@link UnaryCP}, except
     *       IQM and MEDIAN, which use a sort-based plan.</li>
     *   <li>Cumulative aggregates that do not run in CP or on the GPU use the Spark plan from
     *       {@link #constructLopsSparkCumulativeUnary()}.</li>
     *   <li>Everything else becomes a {@link Unary}, multi-threaded for the op types reported
     *       by {@link #isMultiThreadedOpType()}.</li>
     * </ul>
     *
     * @throws HopsException wrapping any failure during lop construction
     */
    @Override
    public Lop constructLops() {
        if (this.getLops() != null) {
            return this.getLops();
        }
        try {
            Hop input = this.getInput().get(0);
            if (this.isScalarOrScalarCast()) {
                if (this._op == Types.OpOp1.IQM) {
                    this.setLops(this.constructLopsIQM());
                } else if (this._op == Types.OpOp1.MEDIAN) {
                    this.setLops(this.constructLopsMedian());
                } else {
                    UnaryCP unaryCP = new UnaryCP(input.constructLops(), this._op, this.getDataType(), this.getValueType());
                    this.setOutputDimensions(unaryCP);
                    this.setLineNumbers(unaryCP);
                    this.setLops(unaryCP);
                }
            } else {
                Types.ExecType et = this.optFindExecType();
                if (this.isCumulativeUnaryOperation() && et != Types.ExecType.CP && et != Types.ExecType.GPU) {
                    this.setLops(this.constructLopsSparkCumulativeUnary());
                } else {
                    int k = this.isMultiThreadedOpType() ? OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads) : 1;
                    Unary unary = new Unary(input.constructLops(), this._op, this.getDataType(), this.getValueType(), et, k, false);
                    this.setOutputDimensions(unary);
                    this.setLineNumbers(unary);
                    this.setLops(unary);
                }
            }
        }
        catch (Exception e) {
            throw new HopsException(this.printErrorLocation() + "error constructing Lops for UnaryOp Hop -- \n ", e);
        }
        this.constructAndSetLopsDataFlowProperties();
        return this.getLops();
    }

    /**
     * Builds the median plan: sort the input by value, then pick the 0.5 quantile.
     * The caller is responsible for registering the returned lop via {@code setLops}.
     */
    private Lop constructLopsMedian() {
        Types.ExecType et = this.optFindExecType();
        Hop input = this.getInput().get(0);
        SortKeys sort = SortKeys.constructSortByValueLop(input.constructLops(), SortKeys.OperationTypes.WithoutWeights, Types.DataType.MATRIX, Types.ValueType.FP64, et);
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getBlocksize(), input.getNnz());
        PickByCount pick = new PickByCount(sort, Data.createLiteralLop(Types.ValueType.FP64, Double.toString(0.5)), this.getDataType(), this.getValueType(), PickByCount.OperationTypes.MEDIAN, et, true);
        pick.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getBlocksize(), this.getNnz());
        this.setLineNumbers(pick);
        return pick;
    }

    /**
     * Builds the interquartile-mean plan: sort the input by value, then pick with IQM semantics.
     * The caller is responsible for registering the returned lop via {@code setLops}.
     */
    private Lop constructLopsIQM() {
        Types.ExecType et = this.optFindExecType();
        Hop input = this.getInput().get(0);
        SortKeys sort = SortKeys.constructSortByValueLop(input.constructLops(), SortKeys.OperationTypes.WithoutWeights, Types.DataType.MATRIX, Types.ValueType.FP64, et);
        sort.getOutputParameters().setDimensions(input.getDim1(), input.getDim2(), input.getBlocksize(), input.getNnz());
        PickByCount pick = new PickByCount(sort, null, this.getDataType(), this.getValueType(), PickByCount.OperationTypes.IQM, et, true);
        pick.getOutputParameters().setDimensions(this.getDim1(), this.getDim2(), this.getBlocksize(), this.getNnz());
        this.setLineNumbers(pick);
        return pick;
    }

    /**
     * Builds the Spark plan for cumulative aggregates.
     * <ol>
     *   <li>If the input is a single row block, apply a 1 x clen offset row filled with the
     *       aggregate's initial value directly.</li>
     *   <li>Otherwise, repeatedly pre-aggregate row blocks ({@link CumulativePartialAggregate})
     *       while the aggregates exceed the local memory budget and have more than one row.
     *       At least one round runs when dims are unknown or Spark is forced.</li>
     *   <li>Compute the cumulative aggregate of the top-level aggregates in CP.</li>
     *   <li>Propagate offsets back down level by level via {@link CumulativeOffsetBinary}.</li>
     * </ol>
     */
    private Lop constructLopsSparkCumulativeUnary() {
        Hop input = this.getInput().get(0);
        long rlen = input.getDim1();
        long clen = input.getDim2();
        long blen = input.getBlocksize();
        boolean force = !this.dimsKnown() || this._etypeForced == Types.ExecType.SPARK;
        Types.AggOp aggtype = this.getCumulativeAggType();
        Lop X = input.constructLops();

        // single row block: no pre-aggregation needed, just offset by the init value
        if (rlen > 0L && clen > 0L && rlen <= blen) {
            Lop offset = HopRewriteUtils.createDataGenOpByVal(new LiteralOp(1L), new LiteralOp(clen), null, Types.DataType.MATRIX, Types.ValueType.FP64, this.getCumulativeInitValue()).constructLops();
            return this.constructCumOffBinary(X, offset, aggtype, rlen, clen, blen);
        }

        // pre-aggregation levels; data.get(i) is the input of level i
        Lop current = X;
        ArrayList<Lop> data = new ArrayList<>();
        while (((double)(2L * OptimizerUtils.estimateSize(current.getOutputParameters().getNumRows(), clen)
                + OptimizerUtils.estimateSize(1L, clen)) > OptimizerUtils.getLocalMemBudget()
                && current.getOutputParameters().getNumRows() > 1L) || force) {
            data.add(current);
            long rlenAgg = (long)Math.ceil((double)current.getOutputParameters().getNumRows() / (double)blen);
            CumulativePartialAggregate preagg = new CumulativePartialAggregate(current, Types.DataType.MATRIX, Types.ValueType.FP64, aggtype, Types.ExecType.SPARK);
            preagg.getOutputParameters().setDimensions(rlenAgg, clen, blen, -1L);
            this.setLineNumbers(preagg);
            current = preagg;
            force = false; // forcing only guarantees the first round
        }

        // in-memory cumulative aggregate of the top-level partial aggregates
        if (current.getOutputParameters().getNumRows() != 1L) {
            int k = OptimizerUtils.getConstrainedNumThreads(this._maxNumThreads);
            Unary unary = new Unary(current, this._op, Types.DataType.MATRIX, Types.ValueType.FP64, Types.ExecType.CP, k, true);
            unary.getOutputParameters().setDimensions(current.getOutputParameters().getNumRows(), clen, blen, -1L);
            this.setLineNumbers(unary);
            current = unary;
        }

        // apply offsets from the deepest level back to the original input
        for (int i = data.size() - 1; i >= 0; i--) {
            current = this.constructCumOffBinary(data.get(i), current, aggtype, rlen, clen, blen);
        }
        return current;
    }

    /**
     * Creates a Spark {@link CumulativeOffsetBinary} that applies {@code offset} to {@code data}.
     * The offset is broadcast if broadcasting is allowed and its estimated size fits the Spark
     * broadcast memory budget.
     */
    private Lop constructCumOffBinary(Lop data, Lop offset, Types.AggOp aggtype, long rlen, long clen, long blen) {
        double initValue = this.getCumulativeInitValue();
        boolean broadcast = ALLOW_CUMAGG_BROADCAST && OptimizerUtils.checkSparkBroadcastMemoryBudget(
            OptimizerUtils.estimateSize(offset.getOutputParameters().getNumRows(), offset.getOutputParameters().getNumCols()));
        CumulativeOffsetBinary binary = new CumulativeOffsetBinary(data, offset, Types.DataType.MATRIX, Types.ValueType.FP64, initValue, broadcast, aggtype, Types.ExecType.SPARK);
        binary.getOutputParameters().setDimensions(rlen, clen, blen, -1L);
        this.setLineNumbers(binary);
        return binary;
    }

    /** @return the aggregate type of a cumulative op, or {@code null} for non-cumulative ops */
    private Types.AggOp getCumulativeAggType() {
        switch (this._op) {
            case CUMSUM:
                return Types.AggOp.SUM;
            case CUMPROD:
                return Types.AggOp.PROD;
            case CUMMIN:
                return Types.AggOp.MIN;
            case CUMMAX:
                return Types.AggOp.MAX;
            case CUMSUMPROD:
                return Types.AggOp.SUM_PROD;
            default:
                return null;
        }
    }

    /** @return the neutral initial value of a cumulative op, or NaN for non-cumulative ops */
    private double getCumulativeInitValue() {
        switch (this._op) {
            case CUMSUM:
            case CUMSUMPROD:
                return 0.0;
            case CUMPROD:
                return 1.0;
            case CUMMIN:
                return Double.POSITIVE_INFINITY;
            case CUMMAX:
                return Double.NEGATIVE_INFINITY;
            default:
                return Double.NaN;
        }
    }

    @Override
    public void computeMemEstimate(MemoTable memo) {
        super.computeMemEstimate(memo);
        // metadata operations (nrow, ncol, length, ...) get a fixed small estimate of 4 bytes
        if (this.isMetadataOperation()) {
            this._memEstimate = 4.0;
        }
    }

    @Override
    protected double computeOutputMemEstimate(long dim1, long dim2, long nnz) {
        // GPU outputs are estimated as dense (sparsity 1.0)
        double sparsity = this.isGPUEnabled() ? 1.0 : OptimizerUtils.getSparsity(dim1, dim2, nnz);
        return OptimizerUtils.estimateSizeExactSparsity(dim1, dim2, sparsity);
    }

    @Override
    protected double computeIntermediateMemEstimate(long dim1, long dim2, long nnz) {
        double ret = 0.0;
        if (this._op == Types.OpOp1.IQM || this._op == Types.OpOp1.MEDIAN) {
            // sort-based plans: budget three times the input estimate
            ret = this.getInput().get(0).getMemEstimate() * 3.0;
        } else if (this.isCumulativeUnaryOperation()) {
            // worst-case sparse representation just below a sparsity of 0.4
            ret += (double)MatrixBlock.estimateSizeSparseInMemory(dim1, dim2, 0.4 - UtilFunctions.DOUBLE_EPS);
        }
        if (this.isGPUEnabled()) {
            ret += (double)OptimizerUtils.estimateSize(dim1, dim2);
        }
        return ret;
    }

    /**
     * Infers output characteristics from the memoized input stats. Sparsity-preserving ops keep
     * the input nnz; CUMSUMPROD yields a single column; all other ops have unknown nnz.
     *
     * @return the inferred characteristics, or {@code null} if the input dims are unknown
     */
    @Override
    protected DataCharacteristics inferOutputCharacteristics(MemoTable memo) {
        DataCharacteristics dc = memo.getAllInputStats(this.getInput().get(0));
        MatrixCharacteristics ret = null;
        if (dc.dimsKnown()) {
            if (SPARSITY_PRESERVING_OPS.contains(this._op)) {
                ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, dc.getNonZeros());
            } else if (this._op == Types.OpOp1.CUMSUMPROD) {
                ret = new MatrixCharacteristics(dc.getRows(), 1L, -1, -1L);
            } else {
                ret = new MatrixCharacteristics(dc.getRows(), dc.getCols(), -1, -1L);
            }
        }
        return ret;
    }

    @Override
    public boolean allowsAllExecTypes() {
        return true;
    }

    /** Operations that are always kept in memory (CP) under the non-memory-based optimizer. */
    private boolean isInMemoryOperation() {
        return this._op == Types.OpOp1.INVERSE;
    }

    public boolean isCumulativeUnaryOperation() {
        return this._op == Types.OpOp1.CUMSUM || this._op == Types.OpOp1.CUMPROD || this._op == Types.OpOp1.CUMMIN
            || this._op == Types.OpOp1.CUMMAX || this._op == Types.OpOp1.CUMSUMPROD;
    }

    public boolean isCastUnaryOperation() {
        return this._op == Types.OpOp1.CAST_AS_MATRIX || this._op == Types.OpOp1.CAST_AS_SCALAR
            || this._op == Types.OpOp1.CAST_AS_FRAME || this._op == Types.OpOp1.CAST_AS_BOOLEAN
            || this._op == Types.OpOp1.CAST_AS_DOUBLE || this._op == Types.OpOp1.CAST_AS_INT;
    }

    public boolean isExpensiveUnaryOperation() {
        return this._op == Types.OpOp1.EXP || this._op == Types.OpOp1.LOG || this._op == Types.OpOp1.SIGMOID
            || this._op == Types.OpOp1.COMPRESS || this._op == Types.OpOp1.DECOMPRESS;
    }

    public boolean isMetadataOperation() {
        return this._op == Types.OpOp1.NROW || this._op == Types.OpOp1.NCOL || this._op == Types.OpOp1.LENGTH
            || this._op == Types.OpOp1.EXISTS || this._op == Types.OpOp1.LINEAGE;
    }

    @Override
    protected Types.ExecType optFindExecType(boolean transitive) {
        this.checkAndSetForcedPlatform();
        if (this._etypeForced != null) {
            this._etype = this._etypeForced;
        } else {
            if (OptimizerUtils.isMemoryBasedOptLevel()) {
                this._etype = this.findExecTypeByMemEstimate();
            } else {
                Hop input = this.getInput().get(0);
                this._etype = input.areDimsBelowThreshold() || input.isVector() || this.isInMemoryOperation()
                    ? Types.ExecType.CP : Types.ExecType.SPARK;
            }
            this.checkAndSetInvalidCPDimsAndSize();
        }
        // pull a CP matrix op into Spark if its non-DataOp Spark input has this hop as its only parent
        if (this._etype == Types.ExecType.CP
            && this._etypeForced != Types.ExecType.CP
            && this.getInput().get(0).optFindExecType() == Types.ExecType.SPARK
            && this.getDataType().isMatrix()
            && !this.isCumulativeUnaryOperation()
            && !this.isCastUnaryOperation()
            && this._op != Types.OpOp1.MEDIAN
            && this._op != Types.OpOp1.IQM
            && !(this.getInput().get(0) instanceof DataOp)
            && this.getInput().get(0).getParent().size() == 1) {
            this._etype = Types.ExecType.SPARK;
        }
        this.setRequiresRecompileIfNecessary();
        // operations that must execute in CP regardless of the decision above
        if (this._op == Types.OpOp1.PRINT || this._op == Types.OpOp1.ASSERT || this._op == Types.OpOp1.STOP
            || this._op == Types.OpOp1.TYPEOF || this._op == Types.OpOp1.INVERSE || this._op == Types.OpOp1.EIGEN
            || this._op == Types.OpOp1.CHOLESKY || this._op == Types.OpOp1.SVD
            || this.getInput().get(0).getDataType() == Types.DataType.LIST || this.isMetadataOperation()) {
            this._etype = Types.ExecType.CP;
        } else {
            this.updateETFed();
            this.setRequiresRecompileIfNecessary();
        }
        return this._etype;
    }

    /** Propagates output dimensions (and nnz for sparsity-preserving ops) from the input. */
    @Override
    public void refreshSizeInformation() {
        Hop input = this.getInput().get(0);
        if (this.getDataType() == Types.DataType.SCALAR) {
            return;
        }
        if ((this._op == Types.OpOp1.CAST_AS_MATRIX || this._op == Types.OpOp1.CAST_AS_FRAME || this._op == Types.OpOp1.CAST_AS_SCALAR)
            && input.getDataType() == Types.DataType.LIST) {
            // list -> column vector of list length (unknown if length <= 1)
            this.setDim1(input.getLength() > 1L ? input.getLength() : -1L);
            this.setDim2(input.getLength() > 1L ? 1L : -1L);
        } else if ((this._op == Types.OpOp1.CAST_AS_MATRIX || this._op == Types.OpOp1.CAST_AS_FRAME)
            && input.getDataType() == Types.DataType.SCALAR) {
            this.setDim1(1L);
            this.setDim2(1L);
        } else if (this._op == Types.OpOp1.CUMSUMPROD) {
            this.setDim1(input.getDim1());
            this.setDim2(1L);
        } else if (this._op == Types.OpOp1.TYPEOF || this._op == Types.OpOp1.DETECTSCHEMA || this._op == Types.OpOp1.COLNAMES) {
            this.setDim1(1L);
            this.setDim2(input.getDim2());
        } else {
            this.setDim1(input.getDim1());
            this.setDim2(input.getDim2());
            if (SPARSITY_PRESERVING_OPS.contains(this._op)) {
                this.setNnz(input.getNnz());
            }
            if (input._compressedOutput && this._op != Types.OpOp1.DECOMPRESS) {
                this.setCompressedOutput(true);
                this.setCompressedSize(input.compressedSize() * 2L);
            }
        }
    }

    @Override
    public Object clone() throws CloneNotSupportedException {
        UnaryOp ret = new UnaryOp();
        ret.clone(this, false);
        ret._op = this._op;
        return ret;
    }

    /** Two unary hops are equal if they share the op and the identical input; PRINT never compares equal. */
    @Override
    public boolean compare(Hop that) {
        if (!(that instanceof UnaryOp) || this._op == Types.OpOp1.PRINT) {
            return false;
        }
        UnaryOp other = (UnaryOp)that;
        return this._op == other._op && this.getInput().get(0) == other.getInput().get(0);
    }
}

