/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.functionobjects;

import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.matrix.data.MatrixIndexes;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Index function for the diagonal reorganization.
 *
 * <p>A single boolean flag, fixed at construction, selects how block indexes are remapped:
 * when {@code true}, block row {@code i} is placed on the diagonal block {@code (i, i)};
 * when {@code false}, it is placed in block column 1, i.e. {@code (i, 1)}.
 * Instances are immutable.
 */
public class DiagIndex extends IndexFunction {
    private static final long serialVersionUID = -5294771266108903886L;

    /** Selects the diagonal ({@code true}) or first-column ({@code false}) block mapping. */
    private final boolean diagV2M;

    private DiagIndex(boolean vectorToMatrix) {
        diagV2M = vectorToMatrix;
    }

    /**
     * Creates a new instance using the diagonal block mapping.
     *
     * @return a fresh {@code DiagIndex} whose flag is {@code true}
     */
    public static DiagIndex getDiagIndexFnObject() {
        return new DiagIndex(true);
    }

    /**
     * Creates a new instance with the requested mapping direction.
     *
     * @param v2m {@code true} for the diagonal block mapping, {@code false} for first-column
     * @return a fresh {@code DiagIndex}; no caching is performed
     */
    public static DiagIndex getDiagIndexFnObject(boolean v2m) {
        return new DiagIndex(v2m);
    }

    /**
     * Maps an input block index to its output block index.
     * The row index is kept; the column index becomes either the row index
     * (diagonal mapping) or the constant 1.
     */
    @Override
    public void execute(MatrixIndexes in, MatrixIndexes out) {
        final long rowIx = in.getRowIndex();
        final long colIx;
        if (diagV2M) {
            colIx = rowIx;
        } else {
            colIx = 1L;
        }
        out.setIndexes(rowIx, colIx);
    }

    /**
     * Maps an input cell to the diagonal cell of the same row.
     *
     * <p>NOTE(review): unlike the block-index variant, this ignores {@code diagV2M} and always
     * produces {@code (row, row)}; presumably it is only reached for the diagonal mapping —
     * confirm against callers.
     */
    @Override
    public void execute(MatrixValue.CellIndex in, MatrixValue.CellIndex out) {
        // Column target mirrors the row: the cell lands on the diagonal.
        out.set(in.row, in.row);
    }

    /**
     * Computes output dimensions from the input dimensions.
     * A single-column input yields a {@code row x row} result; anything else yields
     * {@code row x 1}.
     *
     * @return always {@code false}
     */
    @Override
    public boolean computeDimension(int row, int col, MatrixValue.CellIndex retDim) {
        retDim.set(row, (col == 1) ? row : 1);
        return false;
    }

    /**
     * Computes output characteristics from the input characteristics.
     * A single-column input yields {@code rows x rows}; otherwise {@code rows x 1}.
     * The input block size is reused for both block dimensions of the output.
     *
     * @return always {@code false}
     */
    @Override
    public boolean computeDimension(DataCharacteristics in, DataCharacteristics out) {
        final long nRows = in.getRows();
        final long nCols = (in.getCols() == 1L) ? nRows : 1L;
        out.set(nRows, nCols, in.getBlocksize(), in.getBlocksize());
        return false;
    }
}

