/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.spark;

import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.PlusMultiply;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.instructions.spark.utils.RDDAggregateUtils;
import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.OperationsOnMatrixValues;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import scala.Tuple2;

/**
 * Spark instruction of type {@code CumsumAggregate}.
 *
 * <p>For every block of the input matrix it computes one row of partial
 * aggregates (see {@link RDDCumAggFunction}) and places that row into an
 * output matrix whose number of rows is {@code ceil(inputRows / blocksize)}.
 * Partial output blocks that share the same key are then merged via
 * {@code RDDAggregateUtils.mergeByKey}.
 */
public class CumulativeAggregateSPInstruction
extends AggregateUnarySPInstruction {
    /**
     * Creates the instruction; only reachable through {@link #parseInstruction(String)}.
     *
     * @param op     aggregate unary operator to apply per block
     * @param in1    input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    private CumulativeAggregateSPInstruction(AggregateUnaryOperator op, CPOperand in1, CPOperand out, String opcode, String istr) {
        super(SPInstruction.SPType.CumsumAggregate, op, null, in1, out, null, opcode, istr);
    }

    /**
     * Builds an instruction from its string form.
     *
     * <p>The parsed parts are read as {@code parts[0]} = opcode,
     * {@code parts[1]} = input operand, {@code parts[2]} = output operand.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static CumulativeAggregateSPInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        // presumably validates that exactly two operand fields follow the opcode
        InstructionUtils.checkNumFields(parts, 2);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        AggregateUnaryOperator aggun = InstructionUtils.parseCumulativeAggregateUnaryOperator(opcode);
        return new CumulativeAggregateSPInstruction(aggun, in1, out, opcode, str);
    }

    /**
     * Executes the instruction: maps every input block to its partial aggregate,
     * merges the results by key and registers the output RDD, its lineage and
     * its data characteristics under the output variable name.
     *
     * @param ec execution context; must be a {@link SparkExecutionContext}
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;

        // Output characteristics start as a copy of the input's; only the row
        // count changes, to ceil(rows / blocksize).
        DataCharacteristics mc = sec.getDataCharacteristics(this.input1.getName());
        DataCharacteristics mcOut = new MatrixCharacteristics(mc);
        long rlen = mc.getRows();
        int blen = mc.getBlocksize();
        mcOut.setRows(ceilDiv(rlen, blen));

        // get input
        JavaPairRDD<MatrixIndexes, MatrixBlock> in = sec.getBinaryMatrixBlockRDDHandleForVariable(this.input1.getName());

        // per-block partial aggregation; no raw-type cast so the result stays fully typed
        AggregateUnaryOperator auop = (AggregateUnaryOperator) this._optr;
        JavaPairRDD<MatrixIndexes, MatrixBlock> out = in.mapToPair(new RDDCumAggFunction(auop, rlen, blen));

        // Merge partial blocks by key. The partition count is the larger of the
        // preferred partition count for the output and
        // min(default parallelism, number of output blocks).
        int numParts = SparkUtils.getNumPreferredPartitions(mcOut);
        int minPar = (int) Math.min(SparkExecutionContext.getDefaultParallelism(true), mcOut.getNumBlocks());
        out = RDDAggregateUtils.mergeByKey(out, Math.max(numParts, minPar), false);

        // put output handle, lineage and characteristics into the context
        sec.setRDDHandleForVariable(this.output.getName(), out);
        sec.addLineageRDD(this.output.getName(), this.input1.getName());
        sec.getDataCharacteristics(this.output.getName()).set(mcOut);
    }

    /**
     * Ceiling division carried out in double precision, exactly as
     * {@code (long) Math.ceil((double) numerator / (double) denominator)}.
     */
    private static long ceilDiv(long numerator, int denominator) {
        return (long) Math.ceil((double) numerator / (double) denominator);
    }

    /**
     * Maps one input block to an output block that holds the block's partial
     * aggregate in a single row.
     *
     * <p>The output row-block index is {@code ceil(inputRowIndex / blen)} and the
     * row position inside that output block is {@code (inputRowIndex - 1) % blen}.
     */
    private static class RDDCumAggFunction
    implements PairFunction<Tuple2<MatrixIndexes, MatrixBlock>, MatrixIndexes, MatrixBlock> {
        private static final long serialVersionUID = 11324676268945117L;
        private final AggregateUnaryOperator _op;
        // created lazily on first use in the PlusMultiply branch, hence not final
        private UnaryOperator _uop = null;
        private final long _rlen;
        private final int _blen;

        /**
         * @param op   aggregate unary operator applied to each block
         * @param rlen number of rows of the input matrix
         * @param blen block size of the input matrix
         */
        public RDDCumAggFunction(AggregateUnaryOperator op, long rlen, int blen) {
            this._op = op;
            this._rlen = rlen;
            this._blen = blen;
        }

        /**
         * Computes the partial aggregate of one block and positions it in the output.
         *
         * @param arg0 input block index and block
         * @return output block index and the output block containing the partial aggregate row
         * @throws Exception as declared by {@link PairFunction#call(Object)}
         */
        @Override
        public Tuple2<MatrixIndexes, MatrixBlock> call(Tuple2<MatrixIndexes, MatrixBlock> arg0) throws Exception {
            MatrixIndexes ixIn = arg0._1();
            MatrixBlock blkIn = arg0._2();
            MatrixIndexes ixOut = new MatrixIndexes();
            MatrixBlock blkOut = new MatrixBlock();

            AggregateUnaryOperator aop = this._op;
            if (aop.aggOp.increOp.fn instanceof PlusMultiply) {
                // Special case for a PlusMultiply increment function: build a 1x2 block of
                // [last-row value of the "ucumk+*" result in column 0, product of the slice t2].
                aop.indexFn.execute(ixIn, ixOut);
                if (this._uop == null) {
                    this._uop = new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
                }
                MatrixBlock t1 = blkIn.unaryOperations(this._uop, new MatrixBlock());
                // presumably rows 0..n-1 of column index 1 only (bounds 1, 1)
                MatrixBlock t2 = blkIn.slice(0, blkIn.getNumRows() - 1, 1, 1, new MatrixBlock());
                blkOut.reset(1, 2);
                blkOut.quickSetValue(0, 0, t1.quickGetValue(t1.getNumRows() - 1, 0));
                blkOut.quickSetValue(0, 1, t2.prod());
            } else {
                // general case: regular unary aggregate, then strip the correction
                // rows/columns if the aggregate operator carries one
                OperationsOnMatrixValues.performAggregateUnary(ixIn, blkIn, ixOut, blkOut, aop, this._blen);
                if (aop.aggOp.existsCorrection()) {
                    blkOut.dropLastRowsOrColumns(aop.aggOp.correction);
                }
            }

            // locate the partial aggregate within the output matrix
            long rlenOut = ceilDiv(this._rlen, this._blen);
            long rixOut = ceilDiv(ixIn.getRowIndex(), this._blen);
            // the last output row block may hold fewer than blen rows
            int rlenBlk = (int) Math.min(rlenOut - (rixOut - 1L) * (long) this._blen, (long) this._blen);
            int clenBlk = blkOut.getNumColumns();
            int posBlk = (int) ((ixIn.getRowIndex() - 1L) % (long) this._blen);

            // output block with the aggregate copied into row posBlk
            // (third constructor argument presumably requests a sparse block)
            MatrixBlock blkOut2 = new MatrixBlock(rlenBlk, clenBlk, true);
            blkOut2.copy(posBlk, posBlk, 0, clenBlk - 1, blkOut, true);
            ixOut.setIndexes(rixOut, ixOut.getColumnIndex());
            return new Tuple2<>(ixOut, blkOut2);
        }
    }
}

