/*
 * Decompiled with CFR 0.152.
 */
package org.cleartk.classifier.sigmoid;

import java.util.logging.Logger;
import org.cleartk.classifier.sigmoid.Sigmoid;

/**
 * Fits the two parameters {@code a} and {@code b} of a sigmoid that maps classifier decision
 * values to probabilities. It uses Newton's method with a backtracking line search. The class
 * name suggests the Lin/Weng refinement of Platt's method.
 *
 * <p>For a decision value {@code f}, the model probability of the positive class is
 * {@code p = 1 / (1 + exp(a * f + b))}. The fit minimizes the cross-entropy between {@code p}
 * and smoothed targets:
 * <ul>
 *   <li>{@code (nPlus + 1) / (nPlus + 2)} for positive examples;</li>
 *   <li>{@code 1 / (nMinus + 2)} for negative examples.</li>
 * </ul>
 */
public class LinWengPlatt {

    private static final Logger LOGGER = Logger.getLogger(LinWengPlatt.class.getName());

    /** Maximum number of Newton iterations before {@link ConvergenceFailure} is thrown. */
    private static final int MAX_ITERATIONS = 100;

    /** The line search gives up once the step size drops below this value. */
    private static final double MIN_STEPSIZE = 1.0E-10;

    /** Added to the Hessian diagonal so the 2x2 Newton system stays strictly positive definite. */
    private static final double SIGMA = 1.0E-12;

    /** Iteration stops when both gradient components are smaller than this in absolute value. */
    private static final double GRADIENT_TOLERANCE = 1.0E-5;

    /**
     * Fits a sigmoid to the given decision values and binary labels.
     *
     * @param decisionValues raw classifier outputs, one per example
     * @param labels {@code true} for positive examples, {@code false} for negative ones; must have
     *     the same length as {@code decisionValues}
     * @return a {@code Sigmoid} constructed from the fitted {@code (a, b)} pair
     * @throws IllegalArgumentException if the two arrays differ in length
     * @throws ConvergenceFailure if the gradient is still above tolerance after
     *     {@value #MAX_ITERATIONS} Newton iterations
     */
    public static Sigmoid fit(double[] decisionValues, boolean[] labels) throws ConvergenceFailure {
        if (decisionValues.length != labels.length) {
            throw new IllegalArgumentException("decisionValues.length (" + decisionValues.length
                    + ") != labels.length (" + labels.length + ")");
        }
        int n = labels.length;
        int nPlus = 0;
        for (boolean label : labels) {
            if (label) {
                ++nPlus;
            }
        }
        int nMinus = n - nPlus;

        // Smoothed targets keep the optimum away from probabilities of exactly 0 or 1.
        double hiTarget = (nPlus + 1.0) / (nPlus + 2.0);
        double loTarget = 1.0 / (nMinus + 2.0);
        double[] targets = new double[n];
        for (int i = 0; i < n; ++i) {
            targets[i] = labels[i] ? hiTarget : loTarget;
        }

        // Start with a = 0 and b derived from the class counts.
        double a = 0.0;
        double b = Math.log((nMinus + 1.0) / (nPlus + 1.0));
        double f = objective(decisionValues, targets, a, b);

        int iterations;
        for (iterations = 0; iterations < MAX_ITERATIONS; ++iterations) {
            // Accumulate the gradient (g1, g2) and Hessian (h11, h21, h22) at the current (a, b).
            double h11 = SIGMA;
            double h22 = SIGMA;
            double h21 = 0.0;
            double g1 = 0.0;
            double g2 = 0.0;
            for (int i = 0; i < n; ++i) {
                double fApB = decisionValues[i] * a + b;
                double p; // model probability of the positive class
                double q; // 1 - p
                // Branch on the sign so exp() is only ever applied to a non-positive argument.
                if (fApB >= 0.0) {
                    double e = Math.exp(-fApB);
                    p = e / (1.0 + e);
                    q = 1.0 / (1.0 + e);
                } else {
                    double e = Math.exp(fApB);
                    p = 1.0 / (1.0 + e);
                    q = e / (1.0 + e);
                }
                double d2 = p * q;
                h11 += decisionValues[i] * decisionValues[i] * d2;
                h22 += d2;
                h21 += decisionValues[i] * d2;
                double d1 = targets[i] - p;
                g1 += decisionValues[i] * d1;
                g2 += d1;
            }
            if (Math.abs(g1) < GRADIENT_TOLERANCE && Math.abs(g2) < GRADIENT_TOLERANCE) {
                break;
            }

            // Newton direction -inverse(H) * g, solved in closed form for the 2x2 system.
            double det = h11 * h22 - h21 * h21;
            double dA = -(h22 * g1 - h21 * g2) / det;
            double dB = -(-h21 * g1 + h11 * g2) / det;
            double gd = g1 * dA + g2 * dB;

            // Backtracking line search: halve the step until the Armijo sufficient-decrease test passes.
            double stepsize;
            for (stepsize = 1.0; stepsize >= MIN_STEPSIZE; stepsize /= 2.0) {
                double newA = a + stepsize * dA;
                double newB = b + stepsize * dB;
                // Evaluated over ALL examples, so it is comparable with f.
                double newF = objective(decisionValues, targets, newA, newB);
                if (newF < f + 1.0E-4 * stepsize * gd) {
                    a = newA;
                    b = newB;
                    f = newF;
                    break;
                }
            }
            if (stepsize < MIN_STEPSIZE) {
                LOGGER.fine("line search failure");
                break;
            }
        }
        if (iterations >= MAX_ITERATIONS) {
            throw new ConvergenceFailure("Reaching maximum iterations");
        }
        return new Sigmoid(a, b);
    }

    /**
     * Returns the cross-entropy between the sigmoid probabilities for {@code (a, b)} and the
     * targets, summed over every example.
     *
     * <p>Each term is rearranged by the sign of {@code a * f + b} so that {@code exp} never
     * receives a large positive argument.
     */
    private static double objective(double[] decisionValues, double[] targets, double a, double b) {
        double sum = 0.0;
        for (int i = 0; i < targets.length; ++i) {
            double fApB = decisionValues[i] * a + b;
            if (fApB >= 0.0) {
                sum += targets[i] * fApB + Math.log(1.0 + Math.exp(-fApB));
            } else {
                sum += (targets[i] - 1.0) * fApB + Math.log(1.0 + Math.exp(fApB));
            }
        }
        return sum;
    }

    /** Thrown by {@link LinWengPlatt#fit} when the Newton iteration limit is reached. */
    public static class ConvergenceFailure
    extends Exception {
        private static final long serialVersionUID = -7570320408478887106L;

        public ConvergenceFailure(String message) {
            super(message);
        }
    }
}

