/*
 * Decompiled with CFR 0.152.
 */
package jsat.math;

import jsat.linear.Vec;
import jsat.math.Function;
import jsat.math.FunctionBase;

/**
 * Collection of small static numeric helpers: single-argument function
 * objects, varargs min/max, log-sum-exp, in-place softmax and Horner
 * polynomial evaluation.
 *
 * <p>The class only exposes static members and cannot be instantiated.
 */
public class MathTricks {
    /** Function returning {@code Math.sqrt} of the value at index 0 of the input vector. */
    public static final Function sqrtFunc = new FunctionBase(){
        private static final long serialVersionUID = -5898515135319116600L;

        @Override
        public double f(Vec x) {
            return Math.sqrt(x.get(0));
        }
    };

    /** Function returning the square of the value at index 0 of the input vector. */
    public static final Function sqrdFunc = new FunctionBase(){
        private static final long serialVersionUID = 6831886040279358142L;

        @Override
        public double f(Vec x) {
            double xx = x.get(0);
            return xx * xx;
        }
    };

    /** Function returning the reciprocal ({@code 1.0 / v}) of the value at index 0 of the input vector. */
    public static final Function invsFunc = new FunctionBase(){
        private static final long serialVersionUID = -7745316806635400174L;

        @Override
        public double f(Vec x) {
            return 1.0 / x.get(0);
        }
    };

    /** Function returning {@code Math.log} (natural logarithm) of the value at index 0 of the input vector. */
    public static final Function logFunc = new FunctionBase(){
        private static final long serialVersionUID = -4653355640520837353L;

        @Override
        public double f(Vec x) {
            return Math.log(x.get(0));
        }
    };

    /** Function returning {@code Math.exp} of the value at index 0 of the input vector. */
    public static final Function expFunc = new FunctionBase(){
        private static final long serialVersionUID = 7075309263321302492L;

        @Override
        public double f(Vec x) {
            return Math.exp(x.get(0));
        }
    };

    /** Function returning {@code Math.abs} of the value at index 0 of the input vector. */
    public static final Function absFunc = new FunctionBase(){
        private static final long serialVersionUID = -3706702191562872641L;

        @Override
        public double f(Vec x) {
            return Math.abs(x.get(0));
        }
    };

    /** Not instantiable: every member is static. */
    private MathTricks() {
    }

    /**
     * Returns the largest of the given values, combined with {@link Math#max(double, double)}.
     *
     * @param vals the values to compare
     * @return the maximum of {@code vals}, or {@code Double.NEGATIVE_INFINITY}
     *         when no values are given
     */
    public static double max(double ... vals) {
        // -infinity is the identity element for max.
        double m = Double.NEGATIVE_INFINITY;
        for (double v : vals) {
            m = Math.max(v, m);
        }
        return m;
    }

    /**
     * Returns the smallest of the given values, combined with {@link Math#min(double, double)}.
     *
     * @param vals the values to compare
     * @return the minimum of {@code vals}, or {@code Double.POSITIVE_INFINITY}
     *         when no values are given
     */
    public static double min(double ... vals) {
        // +infinity is the identity element for min. Seeding with -infinity
        // would make every call return -infinity regardless of the input.
        double m = Double.POSITIVE_INFINITY;
        for (double v : vals) {
            m = Math.min(v, m);
        }
        return m;
    }

    /**
     * Computes {@code maxValue + log(sum_i exp(vals[i] - maxValue))}, which is
     * algebraically equal to {@code log(sum_i exp(vals[i]))}.
     *
     * <p>The shift by {@code maxValue} cancels mathematically; when it is the
     * largest element of {@code vals} every exponent is at most zero, so the
     * exponentials cannot overflow.
     *
     * @param vals the values to combine
     * @param maxValue the shift subtracted from each value before exponentiation
     * @return the shifted log-sum-exp of {@code vals}
     */
    public static double logSumExp(Vec vals, double maxValue) {
        double expSum = 0.0;
        for (int i = 0; i < vals.length(); ++i) {
            expSum += Math.exp(vals.get(i) - maxValue);
        }
        return maxValue + Math.log(expSum);
    }

    /**
     * Computes {@code maxValue + log(sum_i exp(vals[i] - maxValue))}, which is
     * algebraically equal to {@code log(sum_i exp(vals[i]))}.
     *
     * <p>The shift by {@code maxValue} cancels mathematically; when it is the
     * largest element of {@code vals} every exponent is at most zero, so the
     * exponentials cannot overflow.
     *
     * @param vals the values to combine
     * @param maxValue the shift subtracted from each value before exponentiation
     * @return the shifted log-sum-exp of {@code vals}
     */
    public static double logSumExp(double[] vals, double maxValue) {
        double expSum = 0.0;
        for (double v : vals) {
            expSum += Math.exp(v - maxValue);
        }
        return maxValue + Math.log(expSum);
    }

    /**
     * Replaces every element of {@code x}, in place, with
     * {@code exp(x[i] - max) / z}, where {@code z} is the sum of the shifted
     * exponentials.
     *
     * <p>When {@code implicitExtra} is {@code true}, {@code z} additionally
     * contains the term {@code exp(0 - max)}, i.e. one extra input fixed at
     * zero that is not stored in {@code x}. In that mode the shift is seeded
     * with {@code 1.0}, so it is never smaller than {@code 1.0}.
     *
     * @param x the values to transform; overwritten with the result
     * @param implicitExtra whether to include the implicit zero-valued term in
     *        the normalizer
     */
    public static void softmax(double[] x, boolean implicitExtra) {
        // Shift used for numerical stability; any value >= every x[i] cancels
        // out of the final ratio.
        double max = implicitExtra ? 1.0 : Double.NEGATIVE_INFINITY;
        for (double v : x) {
            max = Math.max(max, v);
        }
        // Normalizer, pre-loaded with the implicit term's contribution if requested.
        double z = implicitExtra ? Math.exp(-max) : 0.0;
        for (int c = 0; c < x.length; ++c) {
            x[c] = Math.exp(x[c] - max);
            z += x[c];
        }
        for (int c = 0; c < x.length; ++c) {
            x[c] /= z;
        }
    }

    /**
     * Replaces every element of {@code x}, in place, with
     * {@code exp(x.get(i) - max) / z}, where {@code z} is the sum of the
     * shifted exponentials.
     *
     * <p>When {@code implicitExtra} is {@code true}, {@code z} additionally
     * contains the term {@code exp(0 - max)}, i.e. one extra input fixed at
     * zero that is not stored in {@code x}. In that mode the shift is seeded
     * with {@code 1.0}, so it is never smaller than {@code 1.0}.
     *
     * @param x the vector to transform; overwritten with the result
     * @param implicitExtra whether to include the implicit zero-valued term in
     *        the normalizer
     */
    public static void softmax(Vec x, boolean implicitExtra) {
        // Shift used for numerical stability; x.max() is presumably the
        // largest element of the vector.
        double max = implicitExtra ? 1.0 : Double.NEGATIVE_INFINITY;
        max = Math.max(max, x.max());
        // Normalizer, pre-loaded with the implicit term's contribution if requested.
        double z = implicitExtra ? Math.exp(-max) : 0.0;
        for (int c = 0; c < x.length(); ++c) {
            double newVal = Math.exp(x.get(c) - max);
            x.set(c, newVal);
            z += newVal;
        }
        // Normalize all entries by the accumulated sum.
        x.mutableDivide(z);
    }

    /**
     * Evaluates a polynomial with Horner's scheme, reading the coefficients
     * from the first array element to the last. {@code coef[0]} is therefore
     * the coefficient of the highest power and the last element is the
     * constant term.
     *
     * @param coef the polynomial coefficients, highest order first
     * @param x the point at which to evaluate the polynomial
     * @return the polynomial's value at {@code x}; {@code 0.0} for an empty
     *         coefficient array
     */
    public static double hornerPolyR(double[] coef, double x) {
        double result = 0.0;
        for (double c : coef) {
            result = result * x + c;
        }
        return result;
    }

    /**
     * Evaluates a polynomial with Horner's scheme, reading the coefficients
     * from the last array element to the first. {@code coef[0]} is therefore
     * the constant term and the last element is the coefficient of the
     * highest power.
     *
     * @param coef the polynomial coefficients, constant term first
     * @param x the point at which to evaluate the polynomial
     * @return the polynomial's value at {@code x}; {@code 0.0} for an empty
     *         coefficient array
     */
    public static double hornerPoly(double[] coef, double x) {
        double result = 0.0;
        for (int i = coef.length - 1; i >= 0; --i) {
            result = result * x + coef[i];
        }
        return result;
    }
}

