/*
 * Decompiled with CFR 0.152.
 */
package weka.filters.unsupervised.attribute.missingvaluesimputation;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Vector;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.functions.LinearRegression;
import weka.classifiers.functions.Logistic;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.filters.unsupervised.attribute.missingvaluesimputation.AbstractImputation;
import weka.filters.unsupervised.attribute.missingvaluesimputation.Pair;

/**
 * Imputes missing values with the IRMI algorithm of Templ et al.
 *
 * <p>Missing cells are first filled with the column median (numeric attributes)
 * or mode (all other attributes). Then, for a bounded number of epochs, a
 * separate classifier is trained per attribute on the rows where that attribute
 * was originally observed and is used to re-predict the rows where it was
 * originally missing. An attribute stops being updated once the sum of squared
 * changes of its imputed values in one pass drops below {@link #getEpsilon()}.
 *
 * <p>Only attributes that had at least one missing and at least one observed
 * value in the data passed to {@link #doBuildImputation(Instances)} get a
 * classifier; {@link #doImpute(Instance)} leaves missing values of all other
 * attributes untouched.
 */
public class IRMI extends AbstractImputation implements TechnicalInformationHandler {
    /** Command-line flag (without leading dash) for the nominal classifier. */
    public static final String NOMINAL_CLASSIFIER = "nominal-classifier";
    /** Command-line flag (without leading dash) for the numeric classifier. */
    public static final String NUMERIC_CLASSIFIER = "numeric-classifier";
    /** Command-line flag (without leading dash) for the maximum number of epochs. */
    public static final String NUM_EPOCHS = "num-epochs";
    /** Command-line flag (without leading dash) for the early-termination threshold. */
    public static final String EPSILON = "epsilon";

    /** Template classifier copied for every nominal attribute that is imputed. */
    protected Classifier m_nominalClassifier = this.getDefaultNominalClassifier();
    /** Template classifier copied for every non-nominal attribute that is imputed. */
    protected Classifier m_numericClassifier = this.getDefaultNumericClassifier();
    /** Upper bound on the number of passes over the attributes. */
    protected int m_numEpochs = this.getDefaultNumEpochs();
    /** An attribute is considered stable once its summed squared change is below this value. */
    protected double m_epsilon = this.getDefaultEpsilon();
    /** One trained classifier per attribute index; {@code null} where none was built. */
    protected Classifier[] m_classifiers;
    /** Empty copy of the training header, used as the dataset of instances being imputed. */
    protected Instances m_Header;

    /**
     * Returns a description of this imputation method including the reference.
     *
     * @return the description text
     */
    @Override
    public String globalInfo() {
        return "Uses the IRMI algorithm as published by Templ et al in 'Iterative stepwise regression imputation using standard and robust methods'.\n\n" + this.getTechnicalInformation();
    }

    /**
     * Returns the bibliographic reference of the article describing the algorithm.
     *
     * @return the technical information
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        result.setValue(TechnicalInformation.Field.AUTHOR, "Matthias Templ and Alexander Kowarik and Peter Filzmoser");
        result.setValue(TechnicalInformation.Field.TITLE, "Iterative stepwise regression imputation using standard and robust methods");
        result.setValue(TechnicalInformation.Field.JOURNAL, "Computational Statistics & Data Analysis");
        result.setValue(TechnicalInformation.Field.YEAR, "2011");
        result.setValue(TechnicalInformation.Field.VOLUME, "55");
        result.setValue(TechnicalInformation.Field.NUMBER, "10");
        result.setValue(TechnicalInformation.Field.PAGES, "2793-2806");
        result.setValue(TechnicalInformation.Field.ISSN, "0167-9473");
        result.setValue(TechnicalInformation.Field.HTTP, "http://www.statistik.tuwien.ac.at/public/filz/papers/CSDA11TKF.pdf");
        return result;
    }

    /**
     * Lists the options of this class followed by those of the superclass.
     *
     * @return an enumeration of all available options
     */
    @Override
    public Enumeration<Option> listOptions() {
        Vector<Option> result = new Vector<Option>();
        // The default is shown as a class name; concatenating the classifier
        // object itself would print the toString() of an unbuilt model.
        result.addElement(new Option(
            "\t" + this.nominalClassifierTipText() + "\n"
                + "\t(default: " + this.getDefaultNominalClassifier().getClass().getName() + ")",
            NOMINAL_CLASSIFIER, 1, "-" + NOMINAL_CLASSIFIER + " <classname + options>"));
        result.addElement(new Option(
            "\t" + this.numericClassifierTipText() + "\n"
                + "\t(default: " + this.getDefaultNumericClassifier().getClass().getName() + ")",
            NUMERIC_CLASSIFIER, 1, "-" + NUMERIC_CLASSIFIER + " <classname + options>"));
        result.addElement(new Option(
            "\t" + this.numEpochsTipText() + "\n"
                + "\t(default: " + this.getDefaultNumEpochs() + ")",
            NUM_EPOCHS, 1, "-" + NUM_EPOCHS + " <int>"));
        result.addElement(new Option(
            "\t" + this.epsilonTipText() + "\n"
                + "\t(default: " + this.getDefaultEpsilon() + ")",
            EPSILON, 1, "-" + EPSILON + " <double>"));
        result.addAll(Collections.list(super.listOptions()));
        return result.elements();
    }

    /**
     * Returns the current settings as a command-line option array, with the
     * options of the superclass appended.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        List<String> result = new ArrayList<String>();
        result.add("-" + NOMINAL_CLASSIFIER);
        result.add(Utils.toCommandLine((Object) this.m_nominalClassifier));
        result.add("-" + NUMERIC_CLASSIFIER);
        result.add(Utils.toCommandLine((Object) this.m_numericClassifier));
        result.add("-" + NUM_EPOCHS);
        result.add(Integer.toString(this.m_numEpochs));
        result.add("-" + EPSILON);
        result.add(Double.toString(this.m_epsilon));
        Collections.addAll(result, super.getOptions());
        return result.toArray(new String[result.size()]);
    }

    /**
     * Parses the options of this class, then hands the remaining ones to the
     * superclass. Every option that is absent is reset to its default.
     *
     * @param options the command-line options to parse
     * @throws Exception if an option value cannot be parsed, a classifier cannot
     *                   be instantiated, or unknown options remain
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        Classifier nominal = parseClassifier(Utils.getOption(NOMINAL_CLASSIFIER, options));
        this.setNominalClassifier(nominal != null ? nominal : this.getDefaultNominalClassifier());

        Classifier numeric = parseClassifier(Utils.getOption(NUMERIC_CLASSIFIER, options));
        this.setNumericClassifier(numeric != null ? numeric : this.getDefaultNumericClassifier());

        String epochs = Utils.getOption(NUM_EPOCHS, options);
        this.setNumEpochs(epochs.isEmpty() ? this.getDefaultNumEpochs() : Integer.parseInt(epochs));

        String epsilon = Utils.getOption(EPSILON, options);
        this.setEpsilon(epsilon.isEmpty() ? this.getDefaultEpsilon() : Double.parseDouble(epsilon));

        super.setOptions(options);
        Utils.checkForRemainingOptions(options);
    }

    /**
     * Instantiates a classifier from a "classname + options" specification.
     *
     * @param spec the option value; may be empty
     * @return the configured classifier, or {@code null} if the specification
     *         contains no class name (empty or whitespace only)
     * @throws Exception if the classifier cannot be instantiated or configured
     */
    private static Classifier parseClassifier(String spec) throws Exception {
        if (spec.isEmpty()) {
            return null;
        }
        String[] parts = Utils.splitOptions(spec);
        if (parts.length == 0) {
            // Whitespace-only value: nothing to instantiate, caller uses the default.
            return null;
        }
        String className = parts[0];
        // Blank out the class name so only the classifier's own options remain.
        parts[0] = "";
        return (Classifier) Utils.forName(Classifier.class, className, parts);
    }

    /**
     * Returns the default classifier for nominal attributes.
     *
     * @return a new {@link Logistic} instance
     */
    protected Classifier getDefaultNominalClassifier() {
        return new Logistic();
    }

    /**
     * Sets the classifier used as template for nominal attributes.
     *
     * @param value the classifier
     */
    public void setNominalClassifier(Classifier value) {
        this.m_nominalClassifier = value;
    }

    /**
     * Returns the classifier used as template for nominal attributes.
     *
     * @return the classifier
     */
    public Classifier getNominalClassifier() {
        return this.m_nominalClassifier;
    }

    /**
     * Returns the help text for the nominal classifier property.
     *
     * @return the help text
     */
    public String nominalClassifierTipText() {
        return "Nominal classifier to use";
    }

    /**
     * Returns the default classifier for non-nominal attributes.
     *
     * @return a new {@link LinearRegression} instance
     */
    protected Classifier getDefaultNumericClassifier() {
        return new LinearRegression();
    }

    /**
     * Sets the classifier used as template for non-nominal attributes.
     *
     * @param value the classifier
     */
    public void setNumericClassifier(Classifier value) {
        this.m_numericClassifier = value;
    }

    /**
     * Returns the classifier used as template for non-nominal attributes.
     *
     * @return the classifier
     */
    public Classifier getNumericClassifier() {
        return this.m_numericClassifier;
    }

    /**
     * Returns the help text for the numeric classifier property.
     *
     * @return the help text
     */
    public String numericClassifierTipText() {
        return "Numeric classifier to use";
    }

    /**
     * Returns the default maximum number of epochs.
     *
     * @return 100
     */
    protected int getDefaultNumEpochs() {
        return 100;
    }

    /**
     * Sets the maximum number of epochs.
     *
     * @param value the maximum number of passes over the attributes
     */
    public void setNumEpochs(int value) {
        this.m_numEpochs = value;
    }

    /**
     * Returns the maximum number of epochs.
     *
     * @return the maximum number of passes over the attributes
     */
    public int getNumEpochs() {
        return this.m_numEpochs;
    }

    /**
     * Returns the help text for the number-of-epochs property.
     *
     * @return the help text
     */
    public String numEpochsTipText() {
        return "Max number of epochs";
    }

    /**
     * Returns the default early-termination threshold.
     *
     * @return 5.0
     */
    protected double getDefaultEpsilon() {
        return 5.0;
    }

    /**
     * Sets the early-termination threshold.
     *
     * @param value the threshold on the summed squared change of an attribute's
     *              imputed values
     */
    public void setEpsilon(double value) {
        this.m_epsilon = value;
    }

    /**
     * Returns the early-termination threshold.
     *
     * @return the threshold
     */
    public double getEpsilon() {
        return this.m_epsilon;
    }

    /**
     * Returns the help text for the epsilon property.
     *
     * @return the help text
     */
    public String epsilonTipText() {
        return "Epsilon for early termination";
    }

    /**
     * Returns the capabilities of this imputation: numeric, date and nominal
     * attributes with missing values, any class type, missing class values,
     * or no class at all.
     *
     * @return the capabilities
     */
    @Override
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();
        result.disableAll();
        result.enable(Capabilities.Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capabilities.Capability.DATE_ATTRIBUTES);
        result.enable(Capabilities.Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capabilities.Capability.MISSING_VALUES);
        result.enableAllClasses();
        result.enable(Capabilities.Capability.MISSING_CLASS_VALUES);
        result.enable(Capabilities.Capability.NO_CLASS);
        return result;
    }

    /**
     * Computes the median of the non-missing entries of an array.
     *
     * @param vals the values; missing values are ignored
     * @return the median (mean of the two middle values for an even count),
     *         or 0.0 if there is no non-missing value
     */
    protected double median(double[] vals) {
        double[] present = new double[vals.length];
        int count = 0;
        for (double val : vals) {
            if (!Utils.isMissingValue(val)) {
                present[count++] = val;
            }
        }
        if (count == 0) {
            return 0.0;
        }
        Arrays.sort(present, 0, count);
        int midPoint = count / 2;
        if (count % 2 == 0) {
            return (present[midPoint] + present[midPoint - 1]) / 2.0;
        }
        return present[midPoint];
    }

    /**
     * Computes the most frequent non-missing entry of an array.
     *
     * @param vals the values; missing values are ignored
     * @return the most frequent value, or 0.0 if there is no non-missing value;
     *         ties are resolved by the iteration order of the counting map
     */
    protected double mode(double[] vals) {
        Map<Double, Integer> counts = new HashMap<Double, Integer>();
        for (double val : vals) {
            if (Utils.isMissingValue(val)) {
                continue;
            }
            Integer seen = counts.get(val);
            counts.put(val, seen == null ? 1 : seen + 1);
        }
        double bestKey = 0.0;
        int bestCount = 0;
        for (Map.Entry<Double, Integer> entry : counts.entrySet()) {
            // Strict comparison: the first entry reaching the highest count wins.
            if (entry.getValue() > bestCount) {
                bestKey = entry.getKey();
                bestCount = entry.getValue();
            }
        }
        return bestKey;
    }

    /**
     * Trains the per-attribute classifiers on the given data.
     *
     * <p>Works on a copy of {@code data} in which the class values (if a class
     * is set) are blanked and every other missing or infinite cell is initially
     * replaced by the column median or mode. The trained classifiers are stored
     * in {@link #m_classifiers} and an empty copy of the header in
     * {@link #m_Header}.
     *
     * @param data the training data; not modified
     * @return an empty dataset with the same header as {@code data}
     * @throws Exception if a classifier cannot be copied, built or applied
     */
    @Override
    protected Instances doBuildImputation(Instances data) throws Exception {
        Instances df = new Instances(data);
        int originalClassIndex = df.classIndex();
        int numAttributes = df.numAttributes();
        int numInstances = df.numInstances();

        // Blank the class column in the working copy. Skipped when no class is
        // set: there is nothing to blank, and Weka's Instance.setClassValue
        // rejects datasets without a class attribute.
        if (originalClassIndex >= 0) {
            for (int i = 0; i < numInstances; ++i) {
                df.instance(i).setClassValue(Double.NaN);
            }
        }

        // Record, per attribute, which rows are missing and which are observed.
        Pair[] numMissing = new Pair[numAttributes];
        List<List<Integer>> missingIndices = new ArrayList<List<Integer>>(numAttributes);
        List<List<Integer>> observedIndices = new ArrayList<List<Integer>>(numAttributes);
        for (int l = 0; l < numAttributes; ++l) {
            List<Integer> missing = new ArrayList<Integer>();
            List<Integer> observed = new ArrayList<Integer>();
            for (int i = 0; i < numInstances; ++i) {
                if (df.instance(i).isMissing(l)) {
                    missing.add(i);
                } else {
                    observed.add(i);
                }
            }
            missingIndices.add(missing);
            observedIndices.add(observed);
            numMissing[l] = new Pair(missing.size(), l);
        }
        // Processing order follows Pair's natural ordering
        // (presumably by missing count -- confirm against Pair.compareTo).
        Arrays.sort(numMissing);

        // Initial fill: median for numeric attributes, mode otherwise.
        for (int i = 0; i < numAttributes; ++i) {
            if (i == originalClassIndex) {
                continue;
            }
            double[] vals = df.attributeToDoubleArray(i);
            double fill = df.attribute(i).isNumeric() ? this.median(vals) : this.mode(vals);
            for (int y = 0; y < vals.length; ++y) {
                if (Double.isNaN(vals[y]) || Double.isInfinite(vals[y])) {
                    df.instance(y).setValue(i, fill);
                }
            }
        }

        // Attributes that can never be updated (nothing missing, or nothing
        // observed to train on) start out stable; otherwise the all-stable
        // early exit below could never trigger for them.
        boolean[] isStable = new boolean[numAttributes];
        for (int l = 0; l < numAttributes; ++l) {
            isStable[l] = missingIndices.get(l).isEmpty() || observedIndices.get(l).isEmpty();
        }

        this.m_classifiers = new Classifier[numAttributes];
        for (int epoch = 0; epoch < this.getNumEpochs(); ++epoch) {
            for (Pair p : numMissing) {
                int l = p.index;
                if (l == originalClassIndex || isStable[l]) {
                    continue;
                }

                // Train on the rows where attribute l was originally observed.
                List<Integer> observedRows = observedIndices.get(l);
                Instances observed = new Instances(data, observedRows.size());
                for (int row : observedRows) {
                    observed.add(df.instance(row));
                }
                observed.setClassIndex(l);
                Classifier template = df.attribute(l).isNominal()
                    ? this.m_nominalClassifier
                    : this.m_numericClassifier;
                Classifier cls = AbstractClassifier.makeCopy(template);
                cls.buildClassifier(observed);
                this.m_classifiers[l] = cls;

                // Re-predict the rows where attribute l was originally missing.
                df.setClassIndex(l);
                double sumOfSquares = 0.0;
                for (int row : missingIndices.get(l)) {
                    Instance target = df.instance(row);
                    double currentValue = target.value(l);
                    double newValue = cls.classifyInstance(target);
                    target.setValue(l, newValue);
                    sumOfSquares += Math.pow(currentValue - newValue, 2.0);
                }
                if (sumOfSquares < this.m_epsilon) {
                    isStable[l] = true;
                }
            }

            boolean allStable = true;
            for (int j = 0; j < isStable.length; ++j) {
                if (j != originalClassIndex && !isStable[j]) {
                    allStable = false;
                    break;
                }
            }
            if (allStable) {
                break;
            }
        }

        this.m_Header = new Instances(data, 0);
        return new Instances(data, 0);
    }

    /**
     * Imputes the missing values of a single instance with the trained
     * classifiers.
     *
     * <p>The class attribute of {@code inst} and attributes without a trained
     * classifier are left unchanged. Attributes are processed in index order,
     * so a value imputed for an earlier attribute is visible to the classifiers
     * of later attributes.
     *
     * @param inst the instance to impute; not modified
     * @return a copy of {@code inst} with imputed values, attached to the output format
     * @throws IllegalStateException if the imputation has not been built yet
     * @throws Exception if a classifier fails to predict
     */
    @Override
    protected Instance doImpute(Instance inst) throws Exception {
        if (this.m_classifiers == null || this.m_Header == null) {
            throw new IllegalStateException("IRMI imputation has not been built yet");
        }
        int classIndex = inst.classIndex();
        Instance result = (Instance) inst.copy();
        result.setDataset(this.m_Header);
        for (int i = 0; i < result.numAttributes(); ++i) {
            if (i == classIndex || !result.isMissing(i) || this.m_classifiers[i] == null) {
                continue;
            }
            // The classifier for attribute i was trained with i as the class.
            this.m_Header.setClassIndex(i);
            result.setValue(i, this.m_classifiers[i].classifyInstance(result));
        }
        result.setDataset(this.m_OutputFormat);
        return result;
    }
}

