/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamServer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;

/**
 * Base class for parameter-server workers.
 *
 * <p>A worker carries its training configuration (update function, update frequency, epochs,
 * batch size), the data it trains on ({@link #setFeatures}, {@link #setLabels}), a reference to
 * the {@link ParamServer}, and a {@link FunctionCallCPInstruction} that invokes the user-provided
 * update function. Subclasses provide the worker name and the statistics hooks.
 */
public abstract class PSWorker
implements Serializable {
    private static final long serialVersionUID = -3510485051178200118L;

    /** Utility thread pool obtained from {@link LazyWriteBuffer#getUtilThreadPool()}. */
    // NOTE(review): this field is not transient although the class is Serializable; confirm the
    // pool returned by LazyWriteBuffer.getUtilThreadPool() is serializable, or that instances
    // holding it are never serialized.
    protected ExecutorService _tpool = LazyWriteBuffer.getUtilThreadPool();
    /** Identifier of this worker, as passed to the constructor. */
    protected int _workerID;
    /** Number of epochs, as passed to the constructor. */
    protected int _epochs;
    /** Batch size, as passed to the constructor. */
    protected long _batchSize;
    /** Execution context used to resolve and run the update function. */
    protected ExecutionContext _ec;
    /** Parameter server this worker communicates with. */
    protected ParamServer _ps;
    /** The single list-typed output parameter of the update function. */
    protected DataIdentifier _output;
    /** Prepared call instruction for the update function. */
    protected FunctionCallCPInstruction _inst;
    /** Feature matrix assigned to this worker. */
    protected MatrixObject _features;
    /** Label matrix assigned to this worker. */
    protected MatrixObject _labels;
    /** Fully qualified key of the update function (split via DMLProgram.splitFunctionKey). */
    protected String _updFunc;
    /** Update frequency of this worker. */
    protected Statement.PSFrequency _freq;
    /** Number of batches, as passed to the constructor (presumably per epoch — confirm). */
    protected int _nbatches;
    /** Model-averaging flag, as passed to the constructor. */
    protected boolean _modelAvg;

    /**
     * Creates an unconfigured worker. Subclasses using this constructor are responsible for
     * populating the fields and calling {@link #setupUpdateFunction} themselves.
     */
    protected PSWorker() {
    }

    /**
     * Creates a fully configured worker and prepares the update-function call.
     *
     * @param workerID  identifier of this worker
     * @param updFunc   fully qualified key of the update function
     * @param freq      update frequency
     * @param epochs    number of epochs
     * @param batchSize batch size
     * @param ec        execution context used to resolve the update function
     * @param ps        parameter server
     * @param nbatches  number of batches
     * @param modelAvg  model-averaging flag
     * @throws DMLRuntimeException if the update function has an unexpected signature
     */
    protected PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize, ExecutionContext ec, ParamServer ps, int nbatches, boolean modelAvg) {
        this._workerID = workerID;
        this._updFunc = updFunc;
        this._freq = freq;
        this._epochs = epochs;
        this._batchSize = batchSize;
        this._ec = ec;
        this._ps = ps;
        this._nbatches = nbatches;
        this._modelAvg = modelAvg;
        this.setupUpdateFunction(updFunc, ec);
    }

    /**
     * Resolves the update function, validates its signature, and prepares the call instruction.
     *
     * <p>The function must declare a matrix input {@code features}, a matrix input
     * {@code labels}, a list input {@code model}, optionally a list input {@code hyperparams},
     * and exactly one list-typed output. Validation happens before {@link #_inst} and
     * {@link #_output} are assigned, so a rejected function leaves them untouched.
     *
     * @param updFunc fully qualified key of the update function
     * @param ec      execution context whose program contains the function
     * @throws DMLRuntimeException if the function signature does not meet the requirements above
     */
    protected void setupUpdateFunction(String updFunc, ExecutionContext ec) {
        String[] cfn = DMLProgram.splitFunctionKey(updFunc);
        String ns = cfn[0];
        String fname = cfn[1];
        // Look the function up with flag=false first; if it is not there, fall back to flag=true
        // (presumably the optimized variant of the function block — confirm against Program).
        boolean opt = !ec.getProgram().containsFunctionProgramBlock(ns, fname, false);
        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(ns, fname, opt);
        ArrayList<DataIdentifier> inputs = func.getInputParams();
        ArrayList<DataIdentifier> outputs = func.getOutputParams();

        // Validate the signature before mutating any worker state.
        this.checkInput(updFunc, false, inputs, Types.DataType.MATRIX, "features");
        this.checkInput(updFunc, false, inputs, Types.DataType.MATRIX, "labels");
        this.checkInput(updFunc, false, inputs, Types.DataType.LIST, "model");
        this.checkInput(updFunc, true, inputs, Types.DataType.LIST, "hyperparams");
        if (outputs.size() != 1) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc));
        }
        if (outputs.get(0).getDataType() != Types.DataType.LIST) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", updFunc));
        }

        // Bind every declared input by name, keeping its value and data type.
        CPOperand[] boundInputs = inputs.stream()
            .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
            .toArray(CPOperand[]::new);
        ArrayList<String> outputNames = outputs.stream()
            .map(DataIdentifier::getName)
            .collect(Collectors.toCollection(ArrayList::new));
        this._inst = new FunctionCallCPInstruction(ns, fname, opt, boundInputs, func.getInputParamNames(), outputNames, "update function");
        this._output = outputs.get(0);
    }

    /**
     * Checks that {@code inputs} contains exactly one parameter named {@code pname} of data
     * type {@code dt}.
     *
     * @param updFunc  name of the update function, used only in the error message
     * @param optional if true, the absence of any parameter named {@code pname} is accepted;
     *                 if one is present it must still have type {@code dt}
     * @param inputs   declared input parameters of the function
     * @param dt       required data type
     * @param pname    required parameter name
     * @throws DMLRuntimeException if the requirement is not met
     */
    private void checkInput(String updFunc, boolean optional, ArrayList<DataIdentifier> inputs, Types.DataType dt, String pname) {
        boolean present = inputs.stream().anyMatch(input -> pname.equals(input.getName()));
        if (optional && !present) {
            return;
        }
        long matches = inputs.stream()
            .filter(input -> input.getDataType() == dt && pname.equals(input.getName()))
            .count();
        if (matches != 1L) {
            throw new DMLRuntimeException(String.format("The '%s' function should provide an input of '%s' type named '%s'.", updFunc, dt, pname));
        }
    }

    /** Sets the feature matrix this worker trains on. */
    public void setFeatures(MatrixObject features) {
        this._features = features;
    }

    /** Sets the label matrix this worker trains on. */
    public void setLabels(MatrixObject labels) {
        this._labels = labels;
    }

    /** Returns the feature matrix assigned to this worker, or null if none was set. */
    public MatrixObject getFeatures() {
        return this._features;
    }

    /** Returns the label matrix assigned to this worker, or null if none was set. */
    public MatrixObject getLabels() {
        return this._labels;
    }

    /** Returns a human-readable name for this worker. */
    public abstract String getWorkerName();

    /** Statistics hook: increments the worker counter. */
    protected abstract void incWorkerNumber();

    /** Statistics hook: accumulates time spent on local model updates. */
    protected abstract void accLocalModelUpdateTime(Timing time);

    /** Statistics hook: accumulates time spent on batch indexing. */
    protected abstract void accBatchIndexingTime(Timing time);

    /** Statistics hook: accumulates time spent computing gradients. */
    protected abstract void accGradientComputeTime(Timing time);

    /** Statistics hook: accumulates the number of epochs; no-op unless overridden. */
    protected void accNumEpochs(int n) {
    }

    /** Statistics hook: accumulates the number of batches; no-op unless overridden. */
    protected void accNumBatches(int n) {
    }
}

