/*
 * Decompiled with CFR 0.152.
 */
package jsat.classifiers.calibration;

import jsat.classifiers.CategoricalResults;
import jsat.classifiers.DataPoint;
import jsat.classifiers.calibration.BinaryCalibration;
import jsat.classifiers.calibration.BinaryScoreClassifier;
import jsat.math.FastMath;

/**
 * Calibrates the raw scores of a binary score classifier into probabilities
 * by fitting a sigmoid of the form {@code 1 / (1 + exp(A * score + B))}.
 *
 * <p>The two coefficients {@code A} and {@code B} are fitted in
 * {@link #calibrate(boolean[], double[], int)} with Newton's method and a
 * backtracking line search over a regularised cross-entropy objective.
 */
public class PlattCalibration
extends BinaryCalibration {
    private static final long serialVersionUID = 1099230240231262536L;
    /** Slope of the fitted sigmoid (multiplies the base classifier's score). */
    private double A;
    /** Offset of the fitted sigmoid. */
    private double B;
    /** Upper bound on the number of Newton iterations. */
    private double maxIter = 100.0;
    /** Smallest step size tried by the backtracking line search. */
    private double minStep = 1.0E-10;
    /** Small constant added to the Hessian diagonal to keep it invertible. */
    private double sigma = 1.0E-12;

    /**
     * Creates a new Platt calibration wrapper.
     *
     * @param base the score-producing classifier whose outputs are calibrated
     * @param mode the calibration mode, forwarded to the super class
     */
    public PlattCalibration(BinaryScoreClassifier base, BinaryCalibration.CalibrationMode mode) {
        super(base, mode);
    }

    /**
     * Maps the base classifier's score for {@code data} through the fitted
     * sigmoid and returns the resulting two-class distribution.
     *
     * @param data the data point to classify
     * @return results where index 1 holds {@code 1 / (1 + exp(A * score + B))}
     *         and index 0 holds its complement
     */
    @Override
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(2);
        double p_1 = 1.0 / (1.0 + FastMath.exp(this.A * this.base.getScore(data) + this.B));
        cr.setProb(0, 1.0 - p_1);
        cr.setProb(1, p_1);
        return cr;
    }

    /**
     * Fits the sigmoid coefficients {@code A} and {@code B} to the given
     * scores and labels.
     *
     * @param label {@code true} for each positive example
     * @param deci the score of each example
     * @param len the number of leading entries of {@code label} and
     *        {@code deci} to use
     */
    @Override
    protected void calibrate(boolean[] label, double[] deci, int len) {
        // NOTE(review): class priors are counted over the whole label array,
        // not just the first len entries - confirm callers pass len == label.length.
        int prior1 = 0;
        for (boolean positive : label) {
            if (positive) {
                ++prior1;
            }
        }
        int prior0 = label.length - prior1;

        // Smoothed targets are used instead of hard 0/1 labels.
        double hiTarget = ((double)prior1 + 1.0) / ((double)prior1 + 2.0);
        double loTarget = 1.0 / ((double)prior0 + 2.0);
        double[] t = new double[len];
        for (int i = 0; i < len; ++i) {
            t[i] = label[i] ? hiTarget : loTarget;
        }

        // Starting point: flat sigmoid whose offset reflects the class balance.
        this.A = 0.0;
        this.B = FastMath.log(((double)prior0 + 1.0) / ((double)prior1 + 1.0));
        // The initial objective must be evaluated with exactly the same formula
        // as the line-search candidates; otherwise the first sufficient-decrease
        // test compares against a wrong baseline.
        double fval = PlattCalibration.objective(deci, t, len, this.A, this.B);

        int it = 0;
        while ((double)it < this.maxIter) {
            // Hessian (h11, h21, h22) and gradient (g1, g2) of the objective.
            double h11 = this.sigma;
            double h22 = this.sigma;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < len; ++i) {
                double fApB = deci[i] * this.A + this.B;
                double p;
                double q;
                // Branch on the sign so exp() is only ever called on a
                // non-positive argument and cannot overflow.
                if (fApB >= 0.0) {
                    double e = FastMath.exp(-fApB);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    double e = FastMath.exp(fApB);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                double d2 = p * q;
                h11 += deci[i] * deci[i] * d2;
                h22 += d2;
                h21 += deci[i] * d2;
                double d1 = t[i] - p;
                g1 += deci[i] * d1;
                g2 += d1;
            }

            // Converged once both gradient components are (almost) zero.
            if (Math.abs(g1) < 1.0E-5 && Math.abs(g2) < 1.0E-5) {
                break;
            }

            // Newton direction: -H^{-1} g, via the closed-form 2x2 inverse.
            double det = h11 * h22 - h21 * h21;
            double dA = -(h22 * g1 - h21 * g2) / det;
            double dB = -(-h21 * g1 + h11 * g2) / det;
            double gd = g1 * dA + g2 * dB;

            // Backtracking line search: halve the step until the objective
            // shows sufficient decrease or the step becomes too small.
            double stepsize = 1.0;
            while (stepsize >= this.minStep) {
                double newA = this.A + stepsize * dA;
                double newB = this.B + stepsize * dB;
                double newf = PlattCalibration.objective(deci, t, len, newA, newB);
                if (newf < fval + 1.0E-4 * stepsize * gd) {
                    this.A = newA;
                    this.B = newB;
                    fval = newf;
                    break;
                }
                stepsize /= 2.0;
            }
            // Line search failed to find an acceptable step: give up.
            if (stepsize < this.minStep) {
                break;
            }
            ++it;
        }
    }

    /**
     * Evaluates the calibration objective
     * {@code sum_i [ t_i * f_i + log(1 + exp(-f_i)) ]} with
     * {@code f_i = a * deci[i] + b}.
     *
     * <p>For negative {@code f_i} the algebraically identical form
     * {@code (t_i - 1) * f_i + log(1 + exp(f_i))} is used so that
     * {@code exp} never receives a positive argument.
     *
     * @param deci the score of each example
     * @param t the target probability of each example
     * @param len the number of leading entries to use
     * @param a the sigmoid slope to evaluate
     * @param b the sigmoid offset to evaluate
     * @return the objective value at {@code (a, b)}
     */
    private static double objective(double[] deci, double[] t, int len, double a, double b) {
        double f = 0.0;
        for (int i = 0; i < len; ++i) {
            double fApB = deci[i] * a + b;
            if (fApB >= 0.0) {
                f += t[i] * fApB + FastMath.log(1.0 + FastMath.exp(-fApB));
            } else {
                f += (t[i] - 1.0) * fApB + FastMath.log(1.0 + FastMath.exp(fApB));
            }
        }
        return f;
    }

    /**
     * Reports whether weighted data is supported, delegating to the base
     * classifier.
     *
     * @return whatever the base classifier reports
     */
    @Override
    public boolean supportsWeightedData() {
        return this.base.supportsWeightedData();
    }

    /**
     * Creates a copy of this object: the base classifier is cloned, and the
     * fitted coefficients and configuration values are copied over.
     *
     * @return a new {@code PlattCalibration} with the same state
     */
    @Override
    public PlattCalibration clone() {
        PlattCalibration clone = new PlattCalibration(this.base.clone(), this.mode);
        clone.A = this.A;
        clone.B = this.B;
        clone.folds = this.folds;
        clone.holdOut = this.holdOut;
        clone.sigma = this.sigma;
        clone.minStep = this.minStep;
        clone.maxIter = this.maxIter;
        return clone;
    }
}

