/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.runtime.instructions.spark;

import java.io.Serializable;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.function.Function;
import org.apache.sysml.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysml.runtime.instructions.InstructionUtils;
import org.apache.sysml.runtime.instructions.cp.CPOperand;
import org.apache.sysml.runtime.instructions.spark.ComputationSPInstruction;
import org.apache.sysml.runtime.instructions.spark.SPInstruction;
import org.apache.sysml.runtime.matrix.MatrixCharacteristics;
import org.apache.sysml.runtime.matrix.data.MatrixBlock;
import org.apache.sysml.runtime.matrix.data.MatrixIndexes;
import org.apache.sysml.runtime.matrix.operators.TernaryOperator;
import scala.Tuple2;

/**
 * Spark instruction that applies a {@link TernaryOperator} to three operands, each of which is
 * either a matrix (read as a binary-block RDD) or a scalar.
 *
 * <p>Scalar operands are wrapped into a {@link MatrixBlock} built from their double value and
 * shipped inside the per-block function. Matrix operands are combined with
 * {@code JavaPairRDD.join}, so the blocks passed to the operator are those that share the same
 * {@link MatrixIndexes} key. The nested {@code TernaryFunctionXYZ} classes encode the operand
 * layout in their suffix: {@code M} marks a matrix operand, {@code S} a scalar operand, in
 * operand order.
 */
public class TernarySPInstruction extends ComputationSPInstruction {

    /**
     * Creates the instruction; instances are obtained through {@link #parseInstruction(String)}.
     *
     * @param op     ternary operator to apply
     * @param in1    first input operand
     * @param in2    second input operand
     * @param in3    third input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param str    instruction string, forwarded unchanged to the superclass
     */
    private TernarySPInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String str) {
        super(SPInstruction.SPType.Ternary, op, in1, in2, in3, out, opcode, str);
    }

    /**
     * Builds a ternary Spark instruction from its string form.
     *
     * <p>After splitting with {@code InstructionUtils.getInstructionPartsWithValueType}, part 0 is
     * used as the opcode, parts 1 to 3 as the input operands, and part 4 as the output operand.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static TernarySPInstruction parseInstruction(String str) {
        String[] tokens = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = tokens[0];
        CPOperand first = new CPOperand(tokens[1]);
        CPOperand second = new CPOperand(tokens[2]);
        CPOperand third = new CPOperand(tokens[3]);
        CPOperand result = new CPOperand(tokens[4]);
        TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode);
        return new TernarySPInstruction(op, first, second, third, result, opcode, str);
    }

    /**
     * Executes the instruction: resolves each operand to either an RDD or a scalar block, picks
     * the block function matching the operand layout, and registers the resulting RDD together
     * with its lineage on the output variable.
     *
     * @param ec execution context; cast to {@link SparkExecutionContext}
     * @throws IllegalStateException if none of the three operands is a matrix, since there is
     *         then no RDD to map over
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        SparkExecutionContext sec = (SparkExecutionContext) ec;
        boolean isMatrix1 = input1.isMatrix();
        boolean isMatrix2 = input2.isMatrix();
        boolean isMatrix3 = input3.isMatrix();

        // Matrix operands are read as RDD handles; the handle stays null for scalar operands.
        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = isMatrix1
            ? sec.getBinaryBlockRDDHandleForVariable(input1.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in2 = isMatrix2
            ? sec.getBinaryBlockRDDHandleForVariable(input2.getName()) : null;
        JavaPairRDD<MatrixIndexes, MatrixBlock> in3 = isMatrix3
            ? sec.getBinaryBlockRDDHandleForVariable(input3.getName()) : null;

        // Scalar operands are wrapped into a block; the block stays null for matrix operands.
        MatrixBlock m1 = isMatrix1 ? null : new MatrixBlock(ec.getScalarInput(input1).getDoubleValue());
        MatrixBlock m2 = isMatrix2 ? null : new MatrixBlock(ec.getScalarInput(input2).getDoubleValue());
        MatrixBlock m3 = isMatrix3 ? null : new MatrixBlock(ec.getScalarInput(input3).getDoubleValue());

        TernaryOperator op = (TernaryOperator) _optr;

        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
        if (isMatrix1 && !isMatrix2 && !isMatrix3) {
            out = in1.mapValues(new TernaryFunctionMSS(op, m1, m2, m3));
        } else if (!isMatrix1 && isMatrix2 && !isMatrix3) {
            out = in2.mapValues(new TernaryFunctionSMS(op, m1, m2, m3));
        } else if (!isMatrix1 && !isMatrix2 && isMatrix3) {
            out = in3.mapValues(new TernaryFunctionSSM(op, m1, m2, m3));
        } else if (isMatrix1 && isMatrix2 && !isMatrix3) {
            out = in1.join(in2).mapValues(new TernaryFunctionMMS(op, m1, m2, m3));
        } else if (isMatrix1 && !isMatrix2 && isMatrix3) {
            out = in1.join(in3).mapValues(new TernaryFunctionMSM(op, m1, m2, m3));
        } else if (!isMatrix1 && isMatrix2 && isMatrix3) {
            out = in2.join(in3).mapValues(new TernaryFunctionSMM(op, m1, m2, m3));
        } else if (isMatrix1 && isMatrix2 && isMatrix3) {
            out = in1.join(in2).join(in3).mapValues(new TernaryFunctionMMM(op, m1, m2, m3));
        } else {
            // All three operands are scalars: every RDD handle is null, so fail with a clear
            // message instead of dereferencing a null handle.
            throw new IllegalStateException("Ternary Spark instruction requires at least one matrix input, "
                + "but all inputs are scalars: " + input1.getName() + ", " + input2.getName()
                + ", " + input3.getName());
        }

        updateTernaryOutputMatrixCharacteristics(sec);
        sec.setRDDHandleForVariable(output.getName(), out);

        // Record every matrix input as a lineage parent of the output RDD.
        if (isMatrix1) {
            sec.addLineageRDD(output.getName(), input1.getName());
        }
        if (isMatrix2) {
            sec.addLineageRDD(output.getName(), input2.getName());
        }
        if (isMatrix3) {
            sec.addLineageRDD(output.getName(), input3.getName());
        }
    }

    /**
     * Copies matrix characteristics from the inputs to the output.
     *
     * <p>The inputs are visited in operand order and every matrix input whose dimensions are
     * known overwrites the output characteristics, so the last such input determines the result.
     * If no matrix input has known dimensions, the output characteristics are left untouched.
     *
     * @param sec Spark execution context used to look up the characteristics
     */
    protected void updateTernaryOutputMatrixCharacteristics(SparkExecutionContext sec) {
        MatrixCharacteristics mcOut = sec.getMatrixCharacteristics(output.getName());
        CPOperand[] inputs = {input1, input2, input3};
        for (CPOperand input : inputs) {
            if (!input.isMatrix()) {
                continue;
            }
            MatrixCharacteristics mcIn = sec.getMatrixCharacteristics(input.getName());
            if (mcIn.dimsKnown()) {
                mcOut.set(mcIn);
            }
        }
    }

    /** Block function for the layout matrix, matrix, matrix: all operands come from the joined value. */
    private static class TernaryFunctionMMM
    extends TernaryFunction
    implements Function<Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        /** The value is nested as ((first, second), third). */
        @Override
        public MatrixBlock call(Tuple2<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> v1) throws Exception {
            Tuple2<MatrixBlock, MatrixBlock> firstTwo = v1._1();
            return firstTwo._1().ternaryOperations(_op, firstTwo._2(), v1._2(), new MatrixBlock());
        }
    }

    /** Block function for the layout scalar, matrix, matrix: the first operand is the stored scalar block. */
    private static class TernaryFunctionSMM
    extends TernaryFunction
    implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSMM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        /** The value is (second, third). */
        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception {
            return _m1.ternaryOperations(_op, v1._1(), v1._2(), new MatrixBlock());
        }
    }

    /** Block function for the layout matrix, scalar, matrix: the second operand is the stored scalar block. */
    private static class TernaryFunctionMSM
    extends TernaryFunction
    implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMSM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        /** The value is (first, third). */
        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception {
            return v1._1().ternaryOperations(_op, _m2, v1._2(), new MatrixBlock());
        }
    }

    /** Block function for the layout matrix, matrix, scalar: the third operand is the stored scalar block. */
    private static class TernaryFunctionMMS
    extends TernaryFunction
    implements Function<Tuple2<MatrixBlock, MatrixBlock>, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        /** The value is (first, second). */
        @Override
        public MatrixBlock call(Tuple2<MatrixBlock, MatrixBlock> v1) throws Exception {
            return v1._1().ternaryOperations(_op, v1._2(), _m3, new MatrixBlock());
        }
    }

    /** Block function for the layout scalar, scalar, matrix: only the third operand comes from the RDD. */
    private static class TernaryFunctionSSM
    extends TernaryFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSSM(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock v1) throws Exception {
            return _m1.ternaryOperations(_op, _m2, v1, new MatrixBlock());
        }
    }

    /** Block function for the layout scalar, matrix, scalar: only the second operand comes from the RDD. */
    private static class TernaryFunctionSMS
    extends TernaryFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionSMS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock v1) throws Exception {
            return _m1.ternaryOperations(_op, v1, _m3, new MatrixBlock());
        }
    }

    /** Block function for the layout matrix, scalar, scalar: only the first operand comes from the RDD. */
    private static class TernaryFunctionMSS
    extends TernaryFunction
    implements Function<MatrixBlock, MatrixBlock> {
        private static final long serialVersionUID = 1L;

        public TernaryFunctionMSS(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            super(op, m1, m2, m3);
        }

        @Override
        public MatrixBlock call(MatrixBlock v1) throws Exception {
            return v1.ternaryOperations(_op, _m2, _m3, new MatrixBlock());
        }
    }

    /**
     * Serializable state shared by all block functions: the operator plus one block slot per
     * operand. {@code processInstruction} fills a slot only for scalar operands; the slot of a
     * matrix operand is {@code null} and is not read by the matching subclass.
     */
    private static abstract class TernaryFunction
    implements Serializable {
        private static final long serialVersionUID = 8345737737972434426L;
        protected final TernaryOperator _op;
        protected final MatrixBlock _m1;
        protected final MatrixBlock _m2;
        protected final MatrixBlock _m3;

        public TernaryFunction(TernaryOperator op, MatrixBlock m1, MatrixBlock m2, MatrixBlock m3) {
            _op = op;
            _m1 = m1;
            _m2 = m2;
            _m3 = m3;
        }
    }
}

