/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.federated;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.BufferedReader;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.log4j.Logger;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.ExecutionContextMap;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.federated.FederatedWorkerHandlerException;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataFormat;
import org.apache.sysds.runtime.privacy.DMLPrivacyException;
import org.apache.sysds.runtime.privacy.PrivacyMonitor;
import org.apache.sysds.runtime.privacy.propagation.PrivacyPropagator;
import org.apache.sysds.utils.JSONHelper;
import org.apache.sysds.utils.Statistics;
import org.apache.wink.json4j.JSONObject;

public class FederatedWorkerHandler
extends ChannelInboundHandlerAdapter {
    protected static Logger log = Logger.getLogger(FederatedWorkerHandler.class);
    private final ExecutionContextMap _ecm;

    public FederatedWorkerHandler(ExecutionContextMap ecm) {
        this._ecm = ecm;
    }

    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        ctx.writeAndFlush((Object)this.createResponse(msg)).addListener((GenericFutureListener)new CloseListener());
    }

    public FederatedResponse createResponse(Object msg) {
        if (log.isDebugEnabled()) {
            log.debug((Object)("Received: " + msg.getClass().getSimpleName()));
        }
        if (!(msg instanceof FederatedRequest[])) {
            throw new DMLRuntimeException("FederatedWorkerHandler: Received object no instance of 'FederatedRequest[]'.");
        }
        FederatedRequest[] requests = (FederatedRequest[])msg;
        FederatedResponse response = null;
        for (int i = 0; i < requests.length; ++i) {
            FederatedRequest request = requests[i];
            if (log.isInfoEnabled()) {
                log.info((Object)("Executing command " + (i + 1) + "/" + requests.length + ": " + request.getType().name()));
                if (log.isDebugEnabled()) {
                    log.debug((Object)("full command: " + request.toString()));
                }
            }
            PrivacyMonitor.setCheckPrivacy(request.checkPrivacy());
            PrivacyMonitor.clearCheckedConstraints();
            FederatedResponse tmp = this.executeCommand(request);
            FederatedWorkerHandler.conditionalAddCheckedConstraints(request, tmp);
            if (!tmp.isSuccessful()) {
                log.error((Object)("Command " + (Object)((Object)request.getType()) + " failed: " + tmp.getErrorMessage() + "full command: \n" + request.toString()));
                response = response == null || response.isSuccessful() ? tmp : response;
            } else if (request.getType() == FederatedRequest.RequestType.GET_VAR) {
                if (response != null && response.isSuccessful()) {
                    log.error((Object)"Multiple GET_VAR are not supported in single batch of requests.");
                }
                response = tmp;
            } else if (response == null && i == requests.length - 1) {
                response = tmp;
            }
            if (!DMLScript.STATISTICS || request.getType() != FederatedRequest.RequestType.CLEAR || !Statistics.allowWorkerStatistics) continue;
            System.out.println("Federated Worker " + Statistics.display());
            Statistics.reset();
        }
        return response;
    }

    private static void conditionalAddCheckedConstraints(FederatedRequest request, FederatedResponse response) {
        if (request.checkPrivacy()) {
            response.setCheckedConstraints(PrivacyMonitor.getCheckedConstraints());
        }
    }

    private FederatedResponse executeCommand(FederatedRequest request) {
        FederatedRequest.RequestType method = request.getType();
        try {
            switch (method) {
                case READ_VAR: {
                    return this.readData(request);
                }
                case PUT_VAR: {
                    return this.putVariable(request);
                }
                case GET_VAR: {
                    return this.getVariable(request);
                }
                case EXEC_INST: {
                    return this.execInstruction(request);
                }
                case EXEC_UDF: {
                    return this.execUDF(request);
                }
                case CLEAR: {
                    return this.execClear();
                }
            }
            String message = String.format("Method %s is not supported.", new Object[]{method});
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException(message));
        }
        catch (FederatedWorkerHandlerException | DMLPrivacyException ex) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, ex);
        }
        catch (Exception ex) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Exception of type " + ex.getClass() + " thrown when processing request", ex));
        }
    }

    private FederatedResponse readData(FederatedRequest request) {
        FederatedWorkerHandler.checkNumParams(request.getNumParams(), 2);
        String filename = (String)request.getParam(0);
        Types.DataType dt = Types.DataType.valueOf((String)request.getParam(1));
        return this.readData(filename, dt, request.getID(), request.getTID());
    }

    /*
     * Enabled aggressive block sorting
     * Enabled unnecessary exception pruning
     * Enabled aggressive exception aggregation
     */
    private FederatedResponse readData(String filename, Types.DataType dataType, long id, long tid) {
        CacheableData cd;
        MatrixCharacteristics mc = new MatrixCharacteristics();
        mc.setBlocksize(ConfigurationManager.getBlocksize());
        switch (dataType) {
            case MATRIX: {
                cd = new MatrixObject(Types.ValueType.FP64, filename);
                break;
            }
            case FRAME: {
                cd = new FrameObject(filename);
                break;
            }
            default: {
                return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Could not recognize datatype"));
            }
        }
        Types.FileFormat fmt = null;
        boolean header = false;
        try {
            String mtdname = DataExpression.getMTDFileName(filename);
            Path path = new Path(mtdname);
            FileSystem fs = IOUtilFunctions.getFileSystem(mtdname);
            try (BufferedReader br = new BufferedReader(new InputStreamReader((InputStream)fs.open(path)));){
                JSONObject mtd = JSONHelper.parse(br);
                if (mtd == null) {
                    FederatedResponse federatedResponse = new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Could not parse metadata file"));
                    return federatedResponse;
                }
                mc.setRows(mtd.getLong("rows"));
                mc.setCols(mtd.getLong("cols"));
                if (mtd.containsKey("nnz")) {
                    mc.setNonZeros(mtd.getLong("nnz"));
                }
                if (mtd.has("header")) {
                    header = mtd.getBoolean("header");
                }
                cd = (CacheableData)PrivacyPropagator.parseAndSetPrivacyConstraint(cd, mtd);
                fmt = Types.FileFormat.safeValueOf(mtd.getString("format"));
            }
        }
        catch (Exception ex) {
            throw new DMLRuntimeException(ex);
        }
        cd.setMetaData(new MetaDataFormat(mc, fmt));
        cd.setFileFormatProperties(new FileFormatPropertiesCSV(header, ",", false));
        cd.enableCleanup(false);
        this._ecm.get(tid).setVariable(String.valueOf(id), cd);
        if (dataType != Types.DataType.FRAME) return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, id);
        FrameObject frameObject = (FrameObject)cd;
        frameObject.acquireRead();
        frameObject.refreshMetaData();
        frameObject.release();
        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[]{id, frameObject.getSchema()});
    }

    private FederatedResponse putVariable(FederatedRequest request) {
        FederatedWorkerHandler.checkNumParams(request.getNumParams(), 1);
        String varname = String.valueOf(request.getID());
        ExecutionContext ec = this._ecm.get(request.getTID());
        if (ec.containsVariable(varname)) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, "Variable " + request.getID() + " already existing.");
        }
        Data data = null;
        if (request.getParam(0) instanceof CacheBlock) {
            data = ExecutionContext.createCacheableData((CacheBlock)request.getParam(0));
        } else if (request.getParam(0) instanceof ScalarObject) {
            data = (ScalarObject)request.getParam(0);
        }
        ec.setVariable(varname, data);
        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
    }

    private FederatedResponse getVariable(FederatedRequest request) {
        FederatedWorkerHandler.checkNumParams(request.getNumParams(), 0);
        ExecutionContext ec = this._ecm.get(request.getTID());
        if (!ec.containsVariable(String.valueOf(request.getID()))) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, "Variable " + request.getID() + " does not exist at federated worker.");
        }
        Data dataObject = ec.getVariable(String.valueOf(request.getID()));
        dataObject = PrivacyMonitor.handlePrivacy(dataObject);
        switch (dataObject.getDataType()) {
            case MATRIX: 
            case FRAME: 
            case TENSOR: {
                return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ((CacheableData)dataObject).acquireReadAndRelease());
            }
            case LIST: {
                return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ((ListObject)dataObject).getData());
            }
            case SCALAR: {
                return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, dataObject);
            }
        }
        return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Unsupported return datatype " + dataObject.getDataType().name()));
    }

    private FederatedResponse execInstruction(FederatedRequest request) {
        ExecutionContext ec = this._ecm.get(request.getTID());
        BasicProgramBlock pb = new BasicProgramBlock(null);
        pb.getInstructions().clear();
        Instruction receivedInstruction = InstructionParser.parseSingleInstruction((String)request.getParam(0));
        pb.getInstructions().add(receivedInstruction);
        try {
            pb.execute(ec);
        }
        catch (Exception ex) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Exception of type " + ex.getClass() + " thrown when processing EXEC_INST request", ex));
        }
        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
    }

    private FederatedResponse execUDF(FederatedRequest request) {
        FederatedWorkerHandler.checkNumParams(request.getNumParams(), 1);
        ExecutionContext ec = this._ecm.get(request.getTID());
        FederatedUDF udf = (FederatedUDF)request.getParam(0);
        Data[] inputs = (Data[])Arrays.stream(udf.getInputIDs()).mapToObj(id -> ec.getVariable(String.valueOf(id))).map(PrivacyMonitor::handlePrivacy).toArray(Data[]::new);
        try {
            return udf.execute(ec, inputs);
        }
        catch (Exception ex) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Exception of type " + ex.getClass() + " thrown when processing EXEC_UDF request", ex));
        }
    }

    private FederatedResponse execClear() {
        try {
            this._ecm.clear();
        }
        catch (Exception ex) {
            return new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Exception of type " + ex.getClass() + " thrown when processing CLEAR request", ex));
        }
        return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
    }

    private static void checkNumParams(int actual, int ... expected) {
        if (Arrays.stream(expected).anyMatch(x -> x == actual)) {
            return;
        }
        throw new DMLRuntimeException("FederatedWorkerHandler: Received wrong amount of params: expected=" + Arrays.toString(expected) + ", actual=" + actual);
    }

    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        cause.printStackTrace();
        ctx.close();
    }

    private static class CloseListener
    implements ChannelFutureListener {
        private CloseListener() {
        }

        public void operationComplete(ChannelFuture channelFuture) throws InterruptedException {
            if (!channelFuture.isSuccess()) {
                log.error((Object)"Federated Worker Write failed");
                channelFuture.channel().writeAndFlush((Object)new FederatedResponse(FederatedResponse.ResponseType.ERROR, new FederatedWorkerHandlerException("Error while sending response."))).channel().close().sync();
            } else {
                PrivacyMonitor.clearCheckedConstraints();
                channelFuture.channel().close().sync();
            }
        }
    }
}

