/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.LocalPSWorker;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.SparkPSBody;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.ProgramConverter;
import scala.Tuple2;

/**
 * Parameter-server worker that is executed inside a Spark task.
 *
 * <p>Each invocation of {@link #call(Tuple2)} receives a worker id paired with that worker's
 * features and labels. The worker then does the following:
 * <ul>
 *   <li>registers the shipped codegen classes;</li>
 *   <li>rebuilds its execution context from the serialized program;</li>
 *   <li>creates a parameter-server proxy through {@link PSRpcFactory};</li>
 *   <li>runs the inherited no-argument {@code call()} of {@link LocalPSWorker}.</li>
 * </ul>
 * Timings and counters are recorded in Spark {@link LongAccumulator}s supplied by the caller.
 */
public class SparkPSWorker
extends LocalPSWorker
implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
    private static final long serialVersionUID = -8674739573419648732L;

    // Serialized program, parsed per worker by ProgramConverter.parseSparkPSBody.
    private final String _program;
    // Codegen classes (name -> bytes) registered before the program is parsed.
    private final HashMap<String, byte[]> _clsMap;
    // Spark configuration and port used to create the parameter-server proxy.
    private final SparkConf _conf;
    private final int _port;
    // Name of the aggregation function installed on the parameter-server proxy.
    private final String _aggFunc;

    // Timing accumulators; each receives the truncated value of Timing.stop().
    private final LongAccumulator _aSetup;
    private final LongAccumulator _aWorker;
    private final LongAccumulator _aUpdate;
    private final LongAccumulator _aIndex;
    private final LongAccumulator _aGrad;
    // Handed to the parameter-server proxy (presumably RPC timing -- confirm in PSRpcFactory).
    private final LongAccumulator _aRPC;
    // Counters for processed batches and epochs.
    private final LongAccumulator _nBatches;
    private final LongAccumulator _nEpochs;

    /**
     * Creates a Spark parameter-server worker.
     *
     * @param updFunc   name of the update function, set up per worker in {@code configureWorker}
     * @param aggFunc   name of the aggregation function installed on the PS proxy
     * @param freq      parameter-server update frequency
     * @param epochs    number of epochs
     * @param batchSize batch size
     * @param program   serialized program used to rebuild the execution context
     * @param clsMap    codegen classes (name to bytes) to register before parsing the program
     * @param conf      Spark configuration used to create the PS proxy
     * @param port      port used to create the PS proxy
     * @param aSetup    accumulator for worker setup time
     * @param aWorker   accumulator incremented once per {@code incWorkerNumber()} call
     * @param aUpdate   accumulator for local model update time
     * @param aIndex    accumulator for batch indexing time
     * @param aGrad     accumulator for gradient computation time
     * @param aRPC      accumulator passed through to the PS proxy
     * @param aBatches  accumulator for the number of batches
     * @param aEpochs   accumulator for the number of epochs
     */
    public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator aBatches, LongAccumulator aEpochs) {
        _updFunc = updFunc;
        _aggFunc = aggFunc;
        _freq = freq;
        _epochs = epochs;
        _batchSize = batchSize;
        _program = program;
        _clsMap = clsMap;
        _conf = conf;
        _port = port;
        _aSetup = aSetup;
        _aWorker = aWorker;
        _aUpdate = aUpdate;
        _aIndex = aIndex;
        _aGrad = aGrad;
        _aRPC = aRPC;
        _nBatches = aBatches;
        _nEpochs = aEpochs;
    }

    @Override
    public String getWorkerName() {
        return String.format("Spark worker_%d", _workerID);
    }

    /**
     * Spark entry point: configures this worker for the given partition, then runs it.
     *
     * @param input pair of (worker id, (features, labels))
     * @throws Exception if configuration or the inherited worker loop fails
     */
    @Override
    public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
        Timing tSetup = new Timing(true);
        configureWorker(input);
        accSetupTime(tSetup);
        // Dispatches to the inherited no-argument call(), i.e. the local worker loop.
        call();
    }

    /**
     * Prepares this worker for one partition.
     *
     * <p>The steps run in this order:
     * <ul>
     *   <li>set the worker id;</li>
     *   <li>register codegen classes;</li>
     *   <li>parse the program into an execution context;</li>
     *   <li>set up the buffer pool;</li>
     *   <li>create the PS proxy;</li>
     *   <li>install the update and aggregation functions;</li>
     *   <li>wrap features and labels as matrix objects.</li>
     * </ul>
     *
     * @param input pair of (worker id, (features, labels))
     * @throws IOException propagated from the setup helpers
     */
    private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException {
        _workerID = input._1;

        // Codegen classes must be available before the program is parsed.
        for (Map.Entry<String, byte[]> e : _clsMap.entrySet()) {
            CodegenUtils.getClassSync(e.getKey(), e.getValue());
        }

        SparkPSBody body = ProgramConverter.parseSparkPSBody(_program, _workerID);
        _ec = body.getEc();

        RemoteParForUtils.setupBufferPool(_workerID);
        _ps = PSRpcFactory.createSparkPSProxy(_conf, _port, _aRPC);

        setupUpdateFunction(_updFunc, _ec);
        _ps.setupAggFunc(_ec, _aggFunc);

        // Generic field access: no raw-type casts needed.
        Tuple2<MatrixBlock, MatrixBlock> data = input._2;
        setFeatures(ParamservUtils.newMatrixObject(data._1, false));
        setLabels(ParamservUtils.newMatrixObject(data._2, false));
    }

    @Override
    protected void incWorkerNumber() {
        _aWorker.add(1L);
    }

    @Override
    protected void accLocalModelUpdateTime(Timing time) {
        addElapsed(_aUpdate, time);
    }

    @Override
    protected void accBatchIndexingTime(Timing time) {
        addElapsed(_aIndex, time);
    }

    @Override
    protected void accGradientComputeTime(Timing time) {
        addElapsed(_aGrad, time);
    }

    @Override
    protected void accNumEpochs(int n) {
        _nEpochs.add(n);
    }

    @Override
    protected void accNumBatches(int n) {
        _nBatches.add(n);
    }

    /** Records the worker setup time, if a timer was supplied. */
    private void accSetupTime(Timing time) {
        addElapsed(_aSetup, time);
    }

    /**
     * Stops {@code time}, if present, and adds its elapsed value (truncated to long) to {@code acc}.
     * A {@code null} timer is ignored.
     */
    private static void addElapsed(LongAccumulator acc, Timing time) {
        if (time != null) {
            acc.add((long) time.stop());
        }
    }
}

