package adams.data.weka.evaluator;

import adams.core.Utils;
import adams.core.logging.Logger;
import adams.data.weka.evaluator.AbstractCrossvalidatedInstanceEvaluator;
import java.util.Collections;
import java.util.Comparator;
import java.util.Vector;
import java.util.logging.Level;
import weka.classifiers.Classifier;
import weka.classifiers.IntervalEstimator;
import weka.classifiers.functions.GaussianProcessesNoWeights;
import weka.core.Instance;
import weka.core.Instances;

/* loaded from: input_file:adams/data/weka/evaluator/IntervalEstimatorBased.class */
/**
 * Instance evaluator that relies on a classifier implementing {@link IntervalEstimator}.
 * <p>
 * {@link #evaluate(Instances, Instances)} builds the classifier on the first dataset and collects
 * the prediction intervals for every instance of the second one. {@link #findThreshold(Vector)}
 * derives a minimum and a maximum average interval width from these, which
 * {@link #doEvaluate(Instance)} then uses to map the average interval width of a single instance
 * to a score: the narrower the intervals, the higher the score.
 */
public class IntervalEstimatorBased extends AbstractCrossvalidatedInstanceEvaluator<IntervalEstimatorBased.SortedInterval> {
    private static final long serialVersionUID = -7760097633698319552L;

    /** the classifier generating the intervals (must implement IntervalEstimator). */
    protected Classifier m_Classifier;

    /** the confidence level handed to IntervalEstimator#predictIntervals. */
    protected double m_ConfidenceLevel;

    /** whether widths get divided by the class value of the instance. */
    protected boolean m_RelativeWidths;

    /** the upper average-width threshold, determined by findThreshold. */
    protected double m_MaxWidth;

    /** the lower average-width threshold, determined by findThreshold. */
    protected double m_MinWidth;

    /**
     * Container for the prediction intervals of a single instance together with their
     * (optionally relative) average width.
     */
    public static class SortedInterval extends AbstractCrossvalidatedInstanceEvaluator.EvaluationContainer {
        /** the intervals; each row's width is computed as row[1] - row[0]. */
        protected double[][] m_Intervals;

        /** the average width of the intervals (relative if m_RelativeWidths is set). */
        protected double m_AverageWidth;

        /** whether the average width was made relative to the class value. */
        protected boolean m_RelativeWidths;

        /**
         * Initializes the container.
         *
         * @param instance the instance the intervals were predicted for
         * @param intervals the predicted intervals (the outer array is copied, rows are shared)
         * @param relativeWidths whether to divide the average width by the instance's class value
         */
        public SortedInterval(Instance instance, double[][] intervals, boolean relativeWidths) {
            super(instance);
            this.m_Intervals = (double[][]) intervals.clone();
            this.m_RelativeWidths = relativeWidths;
            double width = IntervalEstimatorBased.calcAverageWidth(this.m_Intervals);
            if (this.m_RelativeWidths) {
                width = IntervalEstimatorBased.toRelativeWidth(width, this.m_Instance.classValue());
            }
            this.m_AverageWidth = width;
        }

        /**
         * Returns the stored intervals.
         *
         * @return the intervals
         */
        public double[][] getIntervals() {
            return this.m_Intervals;
        }

        /**
         * Returns the (optionally relative) average width of the intervals.
         *
         * @return the average width
         */
        public double getAverageWidth() {
            return this.m_AverageWidth;
        }

        /**
         * Orders containers first by number of intervals, then lexicographically by the
         * absolute width of each interval. A null argument sorts before this object.
         *
         * @param obj the SortedInterval to compare with
         * @return negative, zero or positive as this is less than, equal to or greater than obj
         */
        @Override
        public int compareTo(Object obj) {
            if (obj == null) {
                return 1;
            }
            SortedInterval other = (SortedInterval) obj;
            int result = Integer.compare(getIntervals().length, other.getIntervals().length);
            for (int i = 0; result == 0 && i < this.m_Intervals.length; i++) {
                result = Double.compare(
                    IntervalEstimatorBased.calcWidth(getIntervals()[i]),
                    IntervalEstimatorBased.calcWidth(other.getIntervals()[i]));
            }
            return result;
        }

        /**
         * Returns a short description of the intervals and their average width.
         *
         * @return the description
         */
        @Override
        public String toString() {
            return "intervals=" + Utils.arrayToString(this.m_Intervals) + ", avg width=" + this.m_AverageWidth;
        }
    }

    /**
     * Returns a string describing the object.
     *
     * @return a description suitable for displaying in the GUI
     */
    public String globalInfo() {
        return "Uses a classifier that produces confidence intervals. The narrower the average width "
            + "of the predicted intervals, the higher the evaluation value.";
    }

    /**
     * Adds the options: classifier, confidence level and relative widths.
     */
    @Override
    public void defineOptions() {
        super.defineOptions();
        this.m_OptionManager.add("classifier", "classifier", new GaussianProcessesNoWeights());
        this.m_OptionManager.add("level", "confidenceLevel", Double.valueOf(0.95d));
        this.m_OptionManager.add("relative", "relativeWidths", false);
    }

    /**
     * Sets the classifier to use. Classifiers not implementing IntervalEstimator are rejected
     * with an error message and the current classifier is kept.
     *
     * @param classifier the classifier
     */
    public void setClassifier(Classifier classifier) {
        if (!(classifier instanceof IntervalEstimator)) {
            getLogger().severe("Classifier must implement " + IntervalEstimator.class.getName());
        } else {
            this.m_Classifier = classifier;
            reset();
        }
    }

    /**
     * Returns the classifier in use.
     *
     * @return the classifier
     */
    public Classifier getClassifier() {
        return this.m_Classifier;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String classifierTipText() {
        return "The classifier to use (must implement " + IntervalEstimator.class.getName() + ").";
    }

    /**
     * Sets the confidence level for generating the intervals.
     *
     * @param d the confidence level (0-1)
     */
    public void setConfidenceLevel(double d) {
        this.m_ConfidenceLevel = d;
        reset();
    }

    /**
     * Returns the confidence level for generating the intervals.
     *
     * @return the confidence level
     */
    public double getConfidenceLevel() {
        return this.m_ConfidenceLevel;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String confidenceLevelTipText() {
        return "The confidence level to use when generating the confidence intervals (0-1).";
    }

    /**
     * Sets whether widths are relative to the class value.
     *
     * @param z true if widths get divided by the class value
     */
    public void setRelativeWidths(boolean z) {
        this.m_RelativeWidths = z;
        reset();
    }

    /**
     * Returns whether widths are relative to the class value.
     *
     * @return true if widths get divided by the class value
     */
    public boolean getRelativeWidths() {
        return this.m_RelativeWidths;
    }

    /**
     * Returns the tip text for this property.
     *
     * @return tip text for this property suitable for displaying in the GUI
     */
    public String relativeWidthsTipText() {
        return "If set to true, then the calculated widths will be relative ones, as they will get divided by the class value of the Instance.";
    }

    /**
     * Determines m_MinWidth (smallest average width) and m_MaxWidth (average width at the
     * position given by m_Threshold) from the collected intervals. Sorts the vector in place.
     *
     * @param vector the collected intervals
     * @return null if successful, otherwise an error message
     */
    @Override
    protected String findThreshold(Vector<SortedInterval> vector) {
        if (vector == null || vector.isEmpty()) {
            return "No intervals collected!";
        }
        // Sort by the same quantity doEvaluate compares against the thresholds; the natural
        // ordering of SortedInterval uses absolute per-interval widths instead.
        Collections.sort(vector, new Comparator<SortedInterval>() {
            @Override
            public int compare(SortedInterval o1, SortedInterval o2) {
                return Double.compare(o1.getAverageWidth(), o2.getAverageWidth());
            }
        });
        this.m_MinWidth = vector.firstElement().getAverageWidth();
        // m_Threshold is inherited (presumably a fraction in [0, 1]); clamp the index so that
        // a threshold rounding up to size() does not run past the end of the vector.
        int index = (int) Math.round(vector.size() * this.m_Threshold);
        index = Math.max(0, Math.min(vector.size() - 1, index));
        this.m_MaxWidth = vector.get(index).getAverageWidth();
        if (isLoggingEnabled()) {
            getLogger().info("Computed thresholds: min=" + this.m_MinWidth + ", max=" + this.m_MaxWidth);
        }
        return null;
    }

    /**
     * Builds the classifier on the first dataset and collects the prediction intervals for
     * each instance of the second. Instances for which no intervals can be obtained are logged
     * and skipped; a failure while building the classifier is logged and yields the intervals
     * collected so far.
     *
     * @param train the data to build the classifier with
     * @param test the data to obtain the intervals for
     * @return the collected intervals
     */
    @Override
    protected Vector<SortedInterval> evaluate(Instances train, Instances test) {
        Vector<SortedInterval> result = new Vector<>();
        try {
            if (isLoggingEnabled()) {
                getLogger().info("Building classifier...");
            }
            this.m_Classifier.buildClassifier(train);
            if (isLoggingEnabled()) {
                getLogger().info("Obtaining intervals...");
            }
            // setClassifier only accepts IntervalEstimator implementations
            IntervalEstimator estimator = (IntervalEstimator) this.m_Classifier;
            for (int i = 0; i < test.numInstances(); i++) {
                Instance inst = test.instance(i);
                try {
                    result.add(new SortedInterval(inst, estimator.predictIntervals(inst, this.m_ConfidenceLevel), this.m_RelativeWidths));
                } catch (Exception e) {
                    getLogger().log(Level.SEVERE, "Error obtaining intervals for test instance #" + (i + 1) + ": " + inst, e);
                }
            }
        } catch (Exception e) {
            getLogger().log(Level.SEVERE, "Failed to evaluate", e);
        }
        return result;
    }

    /**
     * Scores the instance based on the average width of its prediction intervals:
     * 1 at or below m_MinWidth, linearly down to 0.5 at m_MaxWidth, then linearly down to 0
     * at twice m_MaxWidth (and 0 beyond).
     *
     * @param instance the instance to evaluate
     * @return the score, or -1 if the intervals could not be obtained
     */
    @Override
    protected double doEvaluate(Instance instance) {
        try {
            double width = calcAverageWidth(((IntervalEstimator) this.m_Classifier).predictIntervals(instance, this.m_ConfidenceLevel));
            if (this.m_RelativeWidths) {
                width = toRelativeWidth(width, instance.classValue());
            }
            // "<=" (not "<") so that the interpolation below only runs when
            // m_MinWidth < width <= m_MaxWidth, i.e. with a strictly positive denominator.
            if (width <= this.m_MinWidth) {
                return 1.0d;
            }
            if (width > this.m_MaxWidth) {
                double excess = width - this.m_MaxWidth;
                return excess > this.m_MaxWidth ? 0.0d : 0.5d - ((excess / this.m_MaxWidth) / 2.0d);
            }
            return 1.0d - (((width - this.m_MinWidth) / (this.m_MaxWidth - this.m_MinWidth)) / 2.0d);
        } catch (Exception e) {
            getLogger().log(Level.SEVERE, "Failed to evaluate", e);
            return -1.0d;
        }
    }

    /**
     * Makes a width relative to a class value.
     *
     * @param width the absolute width
     * @param classValue the class value to divide by
     * @return width / classValue, or Double.MAX_VALUE if classValue is 0
     */
    protected static double toRelativeWidth(double width, double classValue) {
        return classValue == 0.0d ? Double.MAX_VALUE : width / classValue;
    }

    /**
     * Computes the width of a single interval.
     *
     * @param interval the interval; element 0 is treated as lower, element 1 as upper bound
     * @return interval[1] - interval[0]
     */
    protected static double calcWidth(double[] interval) {
        return interval[1] - interval[0];
    }

    /**
     * Computes the average width of the intervals.
     *
     * @param intervals the intervals
     * @return the mean of the widths, 0 if there are no intervals
     */
    protected static double calcAverageWidth(double[][] intervals) {
        double sum = 0.0d;
        for (double[] interval : intervals) {
            sum += calcWidth(interval);
        }
        if (intervals.length > 0) {
            sum /= intervals.length;
        }
        return sum;
    }
}
