package weka.core.converters;

import adams.core.Index;
import adams.core.Range;
import adams.core.base.BaseString;
import adams.core.io.FileUtils;
import adams.core.management.LocaleHelper;
import adams.core.option.OptionUtils;
import adams.env.Environment;
import gnu.trove.list.array.TIntArrayList;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.Locale;
import java.util.Vector;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.Utils;
import weka.core.WekaOptionUtils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/* loaded from: input_file:weka/core/converters/CNTKSaver.class */
/**
 * Writes a set of Instances to a CNTK text file.
 * <p>
 * Each output line can start with an optional row ID. It is followed by one
 * {@code |name v1 v2 ...} group per configured input range. A nominal class
 * attribute is binarized with {@link NominalToBinary}, transforming all values,
 * before writing. Every input range that references the class attribute is
 * expanded to cover the resulting binary columns.
 */
public class CNTKSaver extends AbstractFileSaver implements BatchConverter {
    private static final long serialVersionUID = 4351243795790752863L;

    /** Whether to print debug information to stdout. */
    protected boolean m_Debug = false;

    /** The (optional) attribute used as row ID. */
    protected Index m_RowID;

    /** The attribute ranges, one per CNTK input. */
    protected Range[] m_Inputs;

    /** The names of the CNTK inputs (parallel to {@link #m_Inputs}). */
    protected BaseString[] m_InputNames;

    /** Whether to write "index:value" pairs instead of dense values. */
    protected boolean m_UseSparseFormat;

    /** Whether to suppress input groups that contain missing values. */
    protected boolean m_SuppressMissing;

    /** The file to write to. */
    protected File m_OutputFile;

    /** The locale used for formatting numbers. */
    protected Locale m_Locale;

    /** Creates a new saver with default options. */
    public CNTKSaver() {
        resetOptions();
    }

    /**
     * Returns a description of this saver for the GUI.
     *
     * @return the description
     */
    public String globalInfo() {
        return "Writes the Instances to a CNTK text file.\nAutomatically turns a nominal class attribute into CNTK's '1-hot encoding'.";
    }

    /** Resets all options to their default values. */
    public void resetOptions() {
        super.resetOptions();
        this.m_Debug = false;
        this.m_RowID = getDefaultRowID();
        this.m_Inputs = getDefaultInputs();
        this.m_InputNames = getDefaultInputNames();
        this.m_UseSparseFormat = false;
        this.m_SuppressMissing = false;
        this.m_OutputFile = null;
        this.m_Locale = LocaleHelper.getSingleton().getDefault();
    }

    /**
     * Lists the available command-line options.
     *
     * @return the option descriptions
     */
    public Enumeration listOptions() {
        Vector vector = new Vector();
        WekaOptionUtils.addOption(vector, debugTipText(), "off", "D");
        WekaOptionUtils.addOption(vector, rowIDTipText(), getDefaultRowID(), "row-id");
        WekaOptionUtils.addOption(vector, inputsTipText(), Utils.arrayToString(getDefaultInputs()), "inputs");
        WekaOptionUtils.addOption(vector, inputNamesTipText(), Utils.arrayToString(getDefaultInputNames()), "input-names");
        WekaOptionUtils.addOption(vector, useSparseFormatTipText(), "no", "use-sparse-format");
        WekaOptionUtils.addOption(vector, suppressMissingTipText(), "no", "suppress-missing");
        return WekaOptionUtils.toEnumeration(vector);
    }

    /**
     * Parses the given command-line options.
     *
     * @param strArr the options to parse
     * @throws Exception if parsing fails
     */
    public void setOptions(String[] strArr) throws Exception {
        setDebug(Utils.getFlag("D", strArr));
        setRowID(WekaOptionUtils.parse(strArr, "row-id", getDefaultRowID()));
        // Names first, inputs last: each setter resizes the other array, so the
        // inputs must determine the final length. Otherwise omitting
        // -input-names would truncate the inputs to zero.
        BaseString[] names = (BaseString[]) WekaOptionUtils.parse(strArr, "input-names", getDefaultInputNames());
        Range[] inputs = WekaOptionUtils.parse(strArr, "inputs", getDefaultInputs());
        setInputNames(names);
        setInputs(inputs);
        setUseSparseFormat(Utils.getFlag("use-sparse-format", strArr));
        setSuppressMissing(Utils.getFlag("suppress-missing", strArr));
        super.setOptions(strArr);
    }

    /**
     * Returns the current options as a command-line array.
     *
     * @return the options
     */
    public String[] getOptions() {
        ArrayList arrayList = new ArrayList();
        WekaOptionUtils.add(arrayList, "D", getDebug());
        if (!getRowID().isEmpty()) {
            WekaOptionUtils.add(arrayList, "row-id", getRowID());
        }
        WekaOptionUtils.add(arrayList, "inputs", getInputs());
        WekaOptionUtils.add(arrayList, "input-names", getInputNames());
        WekaOptionUtils.add(arrayList, "use-sparse-format", getUseSparseFormat());
        WekaOptionUtils.add(arrayList, "suppress-missing", getSuppressMissing());
        arrayList.addAll(Arrays.asList(super.getOptions()));
        return (String[]) arrayList.toArray(new String[arrayList.size()]);
    }

    /**
     * Returns a short description of the file type.
     *
     * @return the description
     */
    public String getFileDescription() {
        return "CNTK Text file";
    }

    /**
     * Returns the supported file extensions.
     *
     * @return the extensions
     */
    public String[] getFileExtensions() {
        return new String[]{".txt"};
    }

    /**
     * Sets whether to output debug information.
     *
     * @param z true to enable debug output
     */
    public void setDebug(boolean z) {
        this.m_Debug = z;
    }

    /**
     * Returns whether debug information is output.
     *
     * @return true if debug output is enabled
     */
    public boolean getDebug() {
        return this.m_Debug;
    }

    /**
     * Returns the tip text for the debug option.
     *
     * @return the tip text
     */
    public String debugTipText() {
        return "Whether to print additional debug information to the console.";
    }

    /**
     * Returns the default row ID (empty, i.e. no row ID).
     *
     * @return the default
     */
    protected Index getDefaultRowID() {
        return new Index();
    }

    /**
     * Sets the attribute to use as row ID.
     *
     * @param index the index, empty for none
     */
    public void setRowID(Index index) {
        this.m_RowID = index;
    }

    /**
     * Returns the attribute used as row ID.
     *
     * @return the index, empty for none
     */
    public Index getRowID() {
        return this.m_RowID;
    }

    /**
     * Returns the tip text for the row ID option.
     *
     * @return the tip text
     */
    public String rowIDTipText() {
        return "The (optional) attribute to use for the row ID.";
    }

    /**
     * Returns the default input ranges (none).
     *
     * @return the default
     */
    protected Range[] getDefaultInputs() {
        return new Range[0];
    }

    /**
     * Sets the input ranges. The input names are resized to match.
     *
     * @param rangeArr the ranges
     */
    public void setInputs(Range[] rangeArr) {
        this.m_Inputs = rangeArr;
        this.m_InputNames = (BaseString[]) adams.core.Utils.adjustArray(this.m_InputNames, this.m_Inputs.length, new BaseString());
    }

    /**
     * Returns the input ranges.
     *
     * @return the ranges
     */
    public Range[] getInputs() {
        return this.m_Inputs;
    }

    /**
     * Returns the tip text for the inputs option.
     *
     * @return the tip text
     */
    public String inputsTipText() {
        return "The attribute ranges determining the inputs (eg for 'features' and 'class').";
    }

    /**
     * Returns the default input names (none).
     *
     * @return the default
     */
    protected BaseString[] getDefaultInputNames() {
        return new BaseString[0];
    }

    /**
     * Sets the input names. The input ranges are resized to match.
     *
     * @param baseStringArr the names
     */
    public void setInputNames(BaseString[] baseStringArr) {
        this.m_InputNames = baseStringArr;
        this.m_Inputs = (Range[]) adams.core.Utils.adjustArray(this.m_Inputs, this.m_InputNames.length, new Range());
    }

    /**
     * Returns the input names.
     *
     * @return the names
     */
    public BaseString[] getInputNames() {
        return this.m_InputNames;
    }

    /**
     * Returns the tip text for the input names option.
     *
     * @return the tip text
     */
    public String inputNamesTipText() {
        return "The names of the inputs (eg 'features' and 'class').";
    }

    /**
     * Sets whether to use the sparse "index:value" format.
     *
     * @param z true for sparse format
     */
    public void setUseSparseFormat(boolean z) {
        this.m_UseSparseFormat = z;
    }

    /**
     * Returns whether the sparse "index:value" format is used.
     *
     * @return true for sparse format
     */
    public boolean getUseSparseFormat() {
        return this.m_UseSparseFormat;
    }

    /**
     * Returns the tip text for the sparse format option.
     *
     * @return the tip text
     */
    public String useSparseFormatTipText() {
        return "If enabled, sparse format is used instead (ie 'index:value').";
    }

    /**
     * Sets whether to suppress input groups containing missing values.
     *
     * @param z true to suppress
     */
    public void setSuppressMissing(boolean z) {
        this.m_SuppressMissing = z;
    }

    /**
     * Returns whether input groups containing missing values are suppressed.
     *
     * @return true if suppressed
     */
    public boolean getSuppressMissing() {
        return this.m_SuppressMissing;
    }

    /**
     * Returns the tip text for the suppress missing option.
     *
     * @return the tip text
     */
    public String suppressMissingTipText() {
        return "If enabled, groups that contain at least one missing value get suppressed completely.";
    }

    /**
     * Returns the capabilities: numeric attributes; nominal, numeric or no class.
     *
     * @return the capabilities
     */
    public Capabilities getCapabilities() {
        Capabilities capabilities = super.getCapabilities();
        capabilities.disableAll();
        capabilities.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        capabilities.enable(Capabilities.Capability.NOMINAL_CLASS);
        capabilities.enable(Capabilities.Capability.NUMERIC_CLASS);
        capabilities.enable(Capabilities.Capability.NO_CLASS);
        return capabilities;
    }

    /**
     * Sets the output file.
     *
     * @param file the file to write to
     * @throws IOException never thrown here
     */
    public void setDestination(File file) throws IOException {
        this.m_OutputFile = file;
    }

    /**
     * Not supported; this saver can only write to files.
     *
     * @param outputStream ignored
     * @throws IOException always
     */
    public void setDestination(OutputStream outputStream) throws IOException {
        throw new IOException("Writing to an outputstream not supported");
    }

    /**
     * Formats a number using {@link #m_Locale}, with explicit text for NaN and
     * the two infinities.
     *
     * @param d the value to format
     * @return the string representation
     */
    protected synchronized String format(double d) {
        if (Double.isNaN(d)) {
            return "NaN";
        }
        if (Double.isInfinite(d)) {
            return d < 0.0d ? "-Infinity" : "+Infinity";
        }
        return adams.core.Utils.doubleToString(d, 12, this.m_Locale);
    }

    /**
     * Writes all instances to the output file in CNTK text format.
     *
     * @throws IOException if no data or no output file is set, if incremental
     *                     mode was already used, if class binarization fails,
     *                     or if writing fails
     * @throws IllegalStateException if no input ranges are defined
     */
    public void writeBatch() throws IOException {
        if (getInstances() == null) {
            throw new IOException("No instances to save!");
        }
        Instances data = getInstances();
        if (this.m_OutputFile == null) {
            throw new IOException("No output file set!");
        }
        // 2 == INCREMENTAL retrieval mode
        if (getRetrieval() == 2) {
            throw new IOException("Batch and incremental saving cannot be mixed.");
        }
        setRetrieval(1);  // BATCH
        setWriteMode(0);  // WRITE
        if (this.m_Inputs.length == 0) {
            throw new IllegalStateException("No input ranges defined!");
        }

        String[] names = resolveInputNames();

        this.m_RowID.setMax(data.numAttributes());
        int rowIdCol = this.m_RowID.getIntIndex();
        if (getDebug()) {
            System.out.println("row ID col (0-based, ignored if -1): " + rowIdCol);
        }

        int[][] indices = resolveInputIndices(data);

        int classIndex = data.classIndex();
        if (classIndex > -1 && data.classAttribute().isNominal()) {
            int numValues = data.classAttribute().numValues();
            data = binarizeClass(data, classIndex);
            expandClassIndices(indices, classIndex, numValues);
            // The class column was replaced in place by numValues columns, so any
            // row-ID column after it moves right, just like the input columns.
            if (rowIdCol > classIndex) {
                rowIdCol += numValues - 1;
                if (getDebug()) {
                    System.out.println("row ID col after binarization (0-based): " + rowIdCol);
                }
            }
        }

        // Without suppression, every input is written for every row.
        TIntArrayList allInputs = new TIntArrayList();
        for (int i = 0; i < indices.length; i++) {
            allInputs.add(i);
        }

        try (BufferedWriter writer = new BufferedWriter(new FileWriter(this.m_OutputFile))) {
            int row = 0;
            for (Instance inst : data) {
                row++;
                TIntArrayList selected = allInputs;
                if (this.m_SuppressMissing) {
                    selected = inputsWithoutMissing(inst, indices);
                    if (getDebug()) {
                        System.out.println("Row #" + row + " / inputs to output: " + selected);
                    }
                }
                // A row without any input to output is skipped entirely.
                if (selected.size() != 0) {
                    writeRow(writer, inst, rowIdCol, names, indices, selected);
                }
            }
        }
    }

    /**
     * Pads or truncates the input names to the number of inputs. Empty names
     * become "input-N", where N is 1-based.
     *
     * @return the effective input names
     */
    private String[] resolveInputNames() {
        this.m_InputNames = (BaseString[]) adams.core.Utils.adjustArray(this.m_InputNames, this.m_Inputs.length, new BaseString());
        String[] result = new String[this.m_InputNames.length];
        for (int i = 0; i < this.m_InputNames.length; i++) {
            result[i] = this.m_InputNames[i].getValue();
            if (result[i].isEmpty()) {
                result[i] = "input-" + (i + 1);
            }
        }
        return result;
    }

    /**
     * Resolves each input range into 0-based attribute indices for the data.
     *
     * @param data the dataset the ranges refer to
     * @return one index array per input
     */
    private int[][] resolveInputIndices(Instances data) {
        int[][] result = new int[this.m_Inputs.length][];
        if (getDebug()) {
            System.out.println("# of inputs: " + this.m_Inputs.length);
        }
        for (int i = 0; i < this.m_Inputs.length; i++) {
            this.m_Inputs[i].setMax(data.numAttributes());
            result[i] = this.m_Inputs[i].getIntIndices();
            if (getDebug()) {
                System.out.println("input " + (i + 1) + " (0-based): " + adams.core.Utils.arrayToString(result[i]));
            }
        }
        return result;
    }

    /**
     * Turns the nominal class attribute into one binary column per label.
     * Works on a copy, so the saver's own instances keep their class index.
     *
     * @param data       the data containing the nominal class
     * @param classIndex the 0-based class index
     * @return the binarized data, with no class set
     * @throws IOException if filtering fails
     */
    private Instances binarizeClass(Instances data, int classIndex) throws IOException {
        Instances copy = new Instances(data);
        // NominalToBinary does not process the class attribute, so unset it.
        copy.setClassIndex(-1);
        NominalToBinary nominalToBinary = new NominalToBinary();
        nominalToBinary.setAttributeIndices("" + (classIndex + 1));
        nominalToBinary.setTransformAllValues(true);
        try {
            nominalToBinary.setInputFormat(copy);
            return Filter.useFilter(copy, nominalToBinary);
        } catch (Exception e) {
            throw new IOException("Failed to binarize class attribute, using: " + OptionUtils.getCommandLine(nominalToBinary), e);
        }
    }

    /**
     * Adjusts the input indices, in place, after the class attribute has been
     * expanded into {@code numValues} columns. Indices after the class move
     * right by {@code numValues - 1}. Each input that contains the class
     * receives the additional binary columns directly after the class column.
     *
     * @param indices    the per-input indices (modified)
     * @param classIndex the 0-based class index before binarization
     * @param numValues  the number of class labels
     */
    private void expandClassIndices(int[][] indices, int classIndex, int numValues) {
        TIntArrayList affected = new TIntArrayList();
        for (int i = 0; i < indices.length; i++) {
            for (int n = 0; n < indices[i].length; n++) {
                if (indices[i][n] > classIndex) {
                    indices[i][n] += numValues - 1;
                } else if (indices[i][n] == classIndex && !affected.contains(i)) {
                    affected.add(i);
                }
            }
        }
        if (getDebug()) {
            System.out.println("Arrays affected by binarization: " + affected);
        }
        for (int a = 0; a < affected.size(); a++) {
            int input = affected.get(a);
            TIntArrayList expanded = new TIntArrayList();
            for (int n = 0; n < indices[input].length; n++) {
                expanded.add(indices[input][n]);
                if (indices[input][n] == classIndex) {
                    for (int v = 1; v < numValues; v++) {
                        expanded.add(classIndex + v);
                    }
                }
            }
            if (getDebug()) {
                System.out.println("Affected array #" + input + " (old): " + Utils.arrayToString(indices[input]));
            }
            indices[input] = expanded.toArray();
            if (getDebug()) {
                System.out.println("Affected array #" + input + " (fixed): " + Utils.arrayToString(indices[input]));
            }
        }
    }

    /**
     * Determines the inputs for which the instance has no missing value.
     *
     * @param inst    the instance to check
     * @param indices the per-input attribute indices
     * @return the 0-based numbers of the inputs without missing values
     */
    private TIntArrayList inputsWithoutMissing(Instance inst, int[][] indices) {
        TIntArrayList result = new TIntArrayList();
        for (int i = 0; i < indices.length; i++) {
            boolean missing = false;
            for (int col : indices[i]) {
                if (inst.isMissing(col)) {
                    missing = true;
                    break;
                }
            }
            if (!missing) {
                result.add(i);
            }
        }
        return result;
    }

    /**
     * Writes one line: an optional row ID, then one "|name values" group for
     * each selected input. Missing values are written as "?". In sparse format,
     * each value is written as "offset:value", where offset is the 0-based
     * position within the input, and zero values are omitted.
     *
     * @param writer   the writer to use
     * @param inst     the instance to write
     * @param rowIdCol the 0-based row ID column, -1 for none
     * @param names    the input names
     * @param indices  the per-input attribute indices
     * @param selected the inputs to write
     * @throws IOException if writing fails
     */
    private void writeRow(BufferedWriter writer, Instance inst, int rowIdCol, String[] names,
                          int[][] indices, TIntArrayList selected) throws IOException {
        if (rowIdCol > -1) {
            // Best effort: a failing row-ID lookup is reported, not fatal.
            String rowId = null;
            try {
                if (!inst.isMissing(rowIdCol)) {
                    rowId = format(inst.value(rowIdCol));
                }
            } catch (Exception e) {
                System.err.println("Failed to write data: " + inst);
                e.printStackTrace();
            }
            if (rowId != null) {
                writer.write(rowId);
                writer.write(" ");
            }
        }
        for (int s = 0; s < selected.size(); s++) {
            int input = selected.get(s);
            int[] cols = indices[input];
            writer.write("|");
            writer.write(names[input]);
            writer.write(" ");
            for (int n = 0; n < cols.length; n++) {
                boolean missing = inst.isMissing(cols[n]);
                double value = missing ? Double.NaN : inst.value(cols[n]);
                if (this.m_UseSparseFormat) {
                    // Sparse entries are omitted for zeros. The index is the
                    // offset within this input's vector, not the attribute index.
                    if (!missing && value == 0.0d) {
                        continue;
                    }
                    writer.write(Integer.toString(n));
                    writer.write(":");
                }
                writer.write(missing ? "?" : format(value));
                writer.write(" ");
            }
        }
        writer.write("\n");
    }

    /**
     * Returns the revision string.
     *
     * @return the revision
     */
    public String getRevision() {
        return RevisionUtils.extract("$Revision$");
    }

    /**
     * Runs the saver from the command line.
     *
     * @param strArr the command-line options
     */
    public static void main(String[] strArr) {
        Environment.setEnvironmentClass(Environment.class);
        runFileSaver(new CNTKSaver(), strArr);
    }
}
