/*
 * Decompiled with CFR 0.152.
 */
package jsat.distributions.kernels;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import jsat.distributions.kernels.KernelPoint;
import jsat.distributions.kernels.KernelTrick;
import jsat.linear.IndexValue;
import jsat.linear.Vec;
import jsat.utils.DoubleList;

/**
 * A group of {@link KernelPoint} objects that all share one set of basis
 * vectors. Each point keeps its own coefficient list ({@code alpha}) over the
 * shared basis, while the basis vectors ({@code vecs}), the kernel cache
 * ({@code kernelAccel}) and the {@code K}/{@code InvK} matrix fields are shared
 * by reference with point 0. Sharing the basis lets the kernel evaluations
 * against a query be computed once and reused for every point (see
 * {@link #dot(Vec, List)}).
 */
public class KernelPoints {
    /** The kernel used by every point. */
    private final KernelTrick k;
    /** Error tolerance forwarded to each {@link KernelPoint}. */
    private double errorTolerance;
    /** Strategy used by {@link #mutableAdd(Vec, Vec, List)} to grow or maintain the basis. */
    private KernelPoint.BudgetStrategy budgetStrategy = KernelPoint.BudgetStrategy.PROJECTION;
    /** Maximum number of basis vectors; unbounded by default. */
    private int maxBudget = Integer.MAX_VALUE;
    /** The kernel points. Index 0 owns the shared basis state; the list is never empty. */
    private final List<KernelPoint> points;

    /**
     * Creates a new group of kernel points.
     *
     * @param k the kernel to use
     * @param points the number of kernel points to create; must be at least 1
     * @param errorTolerance the error tolerance given to each {@link KernelPoint}
     * @throws IllegalArgumentException if {@code points < 1}
     */
    public KernelPoints(KernelTrick k, int points, double errorTolerance) {
        this(k, points, errorTolerance, true);
    }

    /**
     * Creates a new group of kernel points that all share a single, initially
     * empty, basis.
     *
     * @param k the kernel to use
     * @param points the number of kernel points to create; must be at least 1
     * @param errorTolerance the error tolerance given to each {@link KernelPoint}
     * @param mergeGrams not used by this implementation
     * @throws IllegalArgumentException if {@code points < 1}
     */
    public KernelPoints(KernelTrick k, int points, double errorTolerance, boolean mergeGrams) {
        if (points < 1) {
            throw new IllegalArgumentException("Number of points must be positive, not " + points);
        }
        this.k = k;
        this.errorTolerance = errorTolerance;
        this.points = new ArrayList<KernelPoint>(points);
        KernelPoint first = new KernelPoint(k, errorTolerance);
        this.points.add(first);
        first.setMaxBudget(this.maxBudget);
        first.setBudgetStrategy(this.budgetStrategy);
        // The remaining points reference point 0's basis rather than owning their own.
        for (int i = 1; i < points; ++i) {
            this.addNewKernelPoint();
        }
    }

    /**
     * Creates a copy of {@code toCopy}. The kernel is cloned; the error
     * tolerance, budget strategy, maximum budget and every point's
     * coefficients are carried over, and the copied points share one basis
     * exactly as the originals do.
     *
     * @param toCopy the object to copy
     */
    public KernelPoints(KernelPoints toCopy) {
        this.k = toCopy.k.clone();
        this.errorTolerance = toCopy.errorTolerance;
        this.budgetStrategy = toCopy.budgetStrategy;
        this.maxBudget = toCopy.maxBudget;
        final int count = toCopy.points.size();
        this.points = new ArrayList<KernelPoint>(count);
        if (toCopy.points.get(0).getBasisSize() == 0) {
            // Nothing has been added yet: build fresh points the same way the
            // main constructor does, so that they share one (empty) basis.
            KernelPoint first = new KernelPoint(this.k, this.errorTolerance);
            this.points.add(first);
            first.setMaxBudget(this.maxBudget);
            first.setBudgetStrategy(this.budgetStrategy);
            for (int i = 1; i < count; ++i) {
                this.addNewKernelPoint();
            }
        } else {
            // Point 0 of the source owns the shared basis, so copy it and make
            // every other copied point reference that copy's basis.
            // NOTE(review): whether KernelPoint.clone() deep-copies the basis and
            // rebinds it to this.k is not visible from this file -- confirm.
            KernelPoint source = toCopy.points.get(0).clone();
            source.setMaxBudget(this.maxBudget);
            source.setBudgetStrategy(this.budgetStrategy);
            this.points.add(source);
            for (int i = 1; i < count; ++i) {
                KernelPoint toAdd = new KernelPoint(this.k, this.errorTolerance);
                toAdd.setMaxBudget(this.maxBudget);
                toAdd.setBudgetStrategy(this.budgetStrategy);
                this.standardMove(toAdd, source);
                toAdd.kernelAccel = source.kernelAccel;
                toAdd.vecs = source.vecs;
                toAdd.alpha = new DoubleList(toCopy.points.get(i).alpha);
                this.points.add(toAdd);
            }
        }
    }

    /**
     * Sets the budget strategy on this object and on every kernel point.
     *
     * @param budgetStrategy the strategy to use from now on
     */
    public void setBudgetStrategy(KernelPoint.BudgetStrategy budgetStrategy) {
        this.budgetStrategy = budgetStrategy;
        for (KernelPoint kp : this.points) {
            kp.setBudgetStrategy(budgetStrategy);
        }
    }

    /**
     * @return the budget strategy currently in use
     */
    public KernelPoint.BudgetStrategy getBudgetStrategy() {
        return this.budgetStrategy;
    }

    /**
     * @return the kernel used by all points
     */
    public KernelTrick getKernel() {
        return this.k;
    }

    /**
     * Sets the error tolerance on this object and on every kernel point.
     *
     * @param errorTolerance the new tolerance, in [0, 1]
     * @throws IllegalArgumentException if the value is NaN or outside [0, 1]
     */
    public void setErrorTolerance(double errorTolerance) {
        if (Double.isNaN(errorTolerance) || errorTolerance < 0.0 || errorTolerance > 1.0) {
            throw new IllegalArgumentException("Error tolerance must be in [0, 1], not " + errorTolerance);
        }
        this.errorTolerance = errorTolerance;
        for (KernelPoint kp : this.points) {
            kp.setErrorTolerance(errorTolerance);
        }
    }

    /**
     * @return the error tolerance given to the kernel points
     */
    public double getErrorTolerance() {
        return this.errorTolerance;
    }

    /**
     * Sets the maximum basis size on this object and on every kernel point.
     *
     * @param maxBudget the new budget; must be at least 1
     * @throws IllegalArgumentException if {@code maxBudget < 1}
     */
    public void setMaxBudget(int maxBudget) {
        if (maxBudget < 1) {
            throw new IllegalArgumentException("Budget must be positive, not " + maxBudget);
        }
        this.maxBudget = maxBudget;
        for (KernelPoint kp : this.points) {
            kp.setMaxBudget(maxBudget);
        }
    }

    /**
     * @return the maximum basis size
     */
    public int getMaxBudget() {
        return this.maxBudget;
    }

    /**
     * @param k the index of the kernel point
     * @return the value of {@link KernelPoint#getSqrdNorm()} for the k-th point
     */
    public double getSqrdNorm(int k) {
        return this.points.get(k).getSqrdNorm();
    }

    /**
     * Computes the dot product of the k-th point with {@code x}.
     *
     * @param k the index of the kernel point
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @return the result of {@link KernelPoint#dot(Vec, List)} on the k-th point
     */
    public double dot(int k, Vec x, List<Double> qi) {
        return this.points.get(k).dot(x, qi);
    }

    /**
     * Computes the dot product of every point with {@code x}. Each kernel
     * evaluation between a basis vector and {@code x} is done once and reused
     * for all points: {@code result[j] = sum_i alpha_j[i] * k(i, x)}.
     *
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @return an array with one dot product per kernel point
     */
    public double[] dot(Vec x, List<Double> qi) {
        final double[] dots = new double[this.points.size()];
        final KernelPoint shared = this.points.get(0);
        final List<Vec> vecs = shared.vecs;
        final List<Double> cache = shared.kernelAccel;
        for (int i = 0; i < vecs.size(); ++i) {
            final double k_ix = this.k.eval(i, x, qi, vecs, cache);
            for (int j = 0; j < dots.length; ++j) {
                final double alpha = this.points.get(j).alpha.getD(i);
                if (alpha != 0.0) {
                    dots[j] += k_ix * alpha;
                }
            }
        }
        return dots;
    }

    /**
     * @param k the index of the kernel point
     * @param x the other kernel point
     * @return the result of {@link KernelPoint#dot(KernelPoint)} on the k-th point
     */
    public double dot(int k, KernelPoint x) {
        return this.points.get(k).dot(x);
    }

    /**
     * @param k the index of a point in this object
     * @param X another group of kernel points
     * @param j the index of a point in {@code X}
     * @return the dot product of this object's k-th point with {@code X}'s j-th point
     */
    public double dot(int k, KernelPoints X, int j) {
        return this.points.get(k).dot(X.points.get(j));
    }

    /**
     * @param k the index of the kernel point
     * @param x the query vector
     * @param qi the kernel query information for {@code x}
     * @return the result of {@link KernelPoint#dist(Vec, List)} on the k-th point
     */
    public double dist(int k, Vec x, List<Double> qi) {
        return this.points.get(k).dist(x, qi);
    }

    /**
     * @param k the index of the kernel point
     * @param x the other kernel point
     * @return the result of {@link KernelPoint#dist(KernelPoint)} on the k-th point
     */
    public double dist(int k, KernelPoint x) {
        return this.points.get(k).dist(x);
    }

    /**
     * @param k the index of a point in this object
     * @param X another group of kernel points
     * @param j the index of a point in {@code X}
     * @return the distance between this object's k-th point and {@code X}'s j-th point
     */
    public double dist(int k, KernelPoints X, int j) {
        return this.points.get(k).dist(X.points.get(j));
    }

    /**
     * Multiplies the k-th point by {@code c} in place.
     *
     * @param k the index of the kernel point
     * @param c the scalar
     */
    public void mutableMultiply(int k, double c) {
        this.points.get(k).mutableMultiply(c);
    }

    /**
     * Multiplies every point by {@code c} in place.
     *
     * @param c the scalar
     */
    public void mutableMultiply(double c) {
        for (KernelPoint kp : this.points) {
            kp.mutableMultiply(c);
        }
    }

    /**
     * Presumably intended to add {@code c * x_t} to the k-th kernel point.
     * <p>
     * NOTE(review): this method currently has an empty body and does nothing.
     * Use {@link #mutableAdd(Vec, Vec, List)} with a coefficient vector that is
     * non-zero only at index {@code k} instead. TODO: implement by delegating
     * to that method.
     *
     * @param k the index of the kernel point
     * @param c the coefficient for {@code x_t}
     * @param x_t the vector to add
     * @param qi the kernel query information for {@code x_t}
     */
    public void mutableAdd(int k, double c, Vec x_t, List<Double> qi) {
        // NOTE(review): left as a no-op to preserve existing behavior.
    }

    /**
     * Adds {@code cs.get(j) * x_t} to every point {@code j}. How the shared
     * basis grows or is trimmed depends on the current budget strategy.
     *
     * @param x_t the vector to add
     * @param cs one coefficient per kernel point; zero entries are skipped
     * @param qi the kernel query information for {@code x_t}
     * @throws IllegalStateException if the budget strategy is not recognized
     */
    public void mutableAdd(Vec x_t, Vec cs, List<Double> qi) {
        if (cs.nnz() == 0) {
            return;
        }
        if (this.budgetStrategy == KernelPoint.BudgetStrategy.PROJECTION) {
            this.mutableAddProjection(x_t, cs, qi);
        } else if (this.budgetStrategy == KernelPoint.BudgetStrategy.MERGE_RBF) {
            this.mutableAddMergeRbf(x_t, cs, qi);
        } else if (this.budgetStrategy == KernelPoint.BudgetStrategy.STOP) {
            // Once the budget is reached, further vectors are silently dropped.
            if (this.getBasisSize() < this.maxBudget) {
                this.appendToSharedBasis(x_t, cs, qi);
            }
        } else if (this.budgetStrategy == KernelPoint.BudgetStrategy.RANDOM) {
            this.mutableAddRandom(x_t, cs, qi);
        } else {
            throw new IllegalStateException("BUG: Report Me!");
        }
    }

    /**
     * PROJECTION strategy: each non-zero point adds the vector itself. Whenever
     * that grows the basis, the other points are re-pointed at the updated
     * shared state and given a zero coefficient for the new basis vector.
     */
    private void mutableAddProjection(Vec x_t, Vec cs, List<Double> qi) {
        int origSize = this.getBasisSize();
        for (IndexValue iv : cs) {
            final int idx = iv.getIndex();
            final double c = iv.getValue();
            final KernelPoint kp_k = this.points.get(idx);
            if (kp_k.getBasisSize() == 0) {
                // First vector ever: kp_k may have created fresh basis lists,
                // so every other point must reference them.
                kp_k.mutableAdd(c, x_t, qi);
                for (int i = 0; i < this.points.size(); ++i) {
                    if (i == idx) {
                        continue;
                    }
                    KernelPoint kp_i = this.points.get(i);
                    this.standardMove(kp_i, kp_k);
                    kp_i.kernelAccel = kp_k.kernelAccel;
                    kp_i.vecs = kp_k.vecs;
                    kp_i.alpha = new DoubleList(16);
                    kp_i.alpha.add(0.0);
                }
            } else {
                kp_k.mutableAdd(c, x_t, qi);
                // Only re-share when x_t actually became a new basis vector.
                if (origSize != kp_k.getBasisSize()) {
                    for (int i = 0; i < this.points.size(); ++i) {
                        if (i == idx) {
                            continue;
                        }
                        KernelPoint kp_i = this.points.get(i);
                        this.standardMove(kp_i, kp_k);
                        kp_i.alpha.add(0.0);
                    }
                }
            }
            origSize = this.getBasisSize();
        }
    }

    /**
     * MERGE_RBF strategy. Under budget, the first non-zero point adds the
     * vector and the rest record their coefficients. At or over budget, the
     * vector is appended and then two basis vectors are merged into one.
     */
    private void mutableAddMergeRbf(Vec x_t, Vec cs, List<Double> qi) {
        if (this.getBasisSize() < this.maxBudget) {
            Iterator<IndexValue> cIter = cs.getNonZeroIterator();
            IndexValue firstIndx = cIter.next();
            KernelPoint kp_first = this.points.get(firstIndx.getIndex());
            kp_first.mutableAdd(firstIndx.getValue(), x_t, qi);
            while (cIter.hasNext()) {
                IndexValue iv = cIter.next();
                this.points.get(iv.getIndex()).alpha.add(iv.getValue());
            }
            this.addMissingZeros();
            return;
        }
        this.appendToSharedBasis(x_t, cs, qi);
        this.mergeTwoBasisVectors();
    }

    /**
     * Replaces two basis vectors with a single merged one.
     * <ul>
     * <li>Index {@code m} is the basis vector whose squared coefficients,
     * summed over all points, are smallest.</li>
     * <li>Its partner {@code n} minimises
     * {@code sum_points(a_m^2 + a_n^2 + 2*k_mn*a_m*a_n - alpha_z^2)}.</li>
     * <li>The merged vector is {@code h*x_m + (1-h)*x_n}.</li>
     * </ul>
     */
    private void mergeTwoBasisVectors() {
        final KernelPoint kp_0 = this.points.get(0);
        final int basisSize = kp_0.alpha.size();

        int m = 0;
        double alpha_m = this.sumSquaredAlpha(0);
        for (int i = 1; i < basisSize; ++i) {
            double tmp = this.sumSquaredAlpha(i);
            if (tmp < alpha_m) {
                alpha_m = tmp;
                m = i;
            }
        }

        double minLoss = Double.POSITIVE_INFINITY;
        int n = -1;
        double n_h = 0.0;
        double n_k_mz = 0.0;
        double n_k_nz = 0.0;
        double tol = 0.001;
        // tol is relaxed after every pass so that, if every candidate was
        // skipped, a later pass can accept more of them.
        while (n == -1) {
            for (int i = 0; i < basisSize; ++i) {
                if (i == m) {
                    continue;
                }
                double a_m = 0.0;
                double a_n = 0.0;
                for (KernelPoint kp : this.points) {
                    double a1 = kp.alpha.getD(m);
                    double a2 = kp.alpha.getD(i);
                    double normalize = a1 + a2;
                    if (normalize < 1.0E-7) {
                        continue;
                    }
                    a_m += a1 / normalize;
                    a_n += a2 / normalize;
                }
                // Skip only this candidate when its normalized weights cancel
                // out; the remaining candidates must still be considered.
                if (Math.abs(a_m + a_n) < tol) {
                    continue;
                }
                double k_mn = this.k.eval(i, m, kp_0.vecs, kp_0.kernelAccel);
                double h = KernelPoint.getH(k_mn, a_m, a_n);
                double k_mz = Math.pow(k_mn, (1.0 - h) * (1.0 - h));
                double k_nz = Math.pow(k_mn, h * h);
                double loss = 0.0;
                for (KernelPoint kp : this.points) {
                    double aml = kp.alpha.getD(m);
                    double anl = kp.alpha.getD(i);
                    double alpha_z = aml * k_mz + anl * k_nz;
                    loss += aml * aml + anl * anl + 2.0 * k_mn * aml * anl - alpha_z * alpha_z;
                }
                if (loss < minLoss) {
                    minLoss = loss;
                    n = i;
                    n_h = h;
                    n_k_mz = k_mz;
                    n_k_nz = k_nz;
                }
            }
            tol /= 10.0;
        }

        Vec n_z = kp_0.vecs.get(m).multiply(n_h);
        n_z.mutableAdd(1.0 - n_h, kp_0.vecs.get(n));
        List<Double> nz_qi = this.k.getQueryInfo(n_z);
        for (int z = 0; z < this.points.size(); ++z) {
            KernelPoint kp = this.points.get(z);
            double aml = kp.alpha.getD(m);
            double anl = kp.alpha.getD(n);
            double alpha_z = aml * n_k_mz + anl * n_k_nz;
            // The final flag is true only for point 0, the owner of the shared basis.
            kp.finalMergeStep(m, n, n_z, nz_qi, alpha_z, z == 0);
        }
    }

    /** Returns the sum over all points of the squared coefficient at basis index {@code idx}. */
    private double sumSquaredAlpha(int idx) {
        double sum = 0.0;
        for (KernelPoint kp : this.points) {
            sum += Math.pow(kp.alpha.getD(idx), 2.0);
        }
        return sum;
    }

    /**
     * RANDOM strategy: when at or over budget, removes a uniformly chosen
     * basis index, then appends the new vector.
     */
    private void mutableAddRandom(Vec x_t, Vec cs, List<Double> qi) {
        final int basisSize = this.getBasisSize();
        if (basisSize >= this.maxBudget) {
            int toRemove = new Random().nextInt(basisSize);
            // NOTE(review): point 0 has the index removed only when the basis
            // is exactly at budget, while the other points always do. Confirm
            // this asymmetry matches KernelPoint.removeIndex's handling of the
            // shared basis.
            if (basisSize == this.maxBudget) {
                this.points.get(0).removeIndex(toRemove);
            }
            for (int i = 1; i < this.points.size(); ++i) {
                this.points.get(i).removeIndex(toRemove);
            }
        }
        this.appendToSharedBasis(x_t, cs, qi);
    }

    /**
     * Appends {@code x_t} (and its query info, if a cache is in use) to the
     * basis held by point 0. Each non-zero coefficient in {@code cs} is
     * recorded on its point, and every point's coefficients are then
     * zero-padded to the basis size.
     */
    private void appendToSharedBasis(Vec x_t, Vec cs, List<Double> qi) {
        final KernelPoint owner = this.points.get(0);
        owner.vecs.add(x_t);
        if (owner.kernelAccel != null) {
            owner.kernelAccel.addAll(qi);
        }
        for (IndexValue iv : cs) {
            this.points.get(iv.getIndex()).alpha.add(iv.getValue());
        }
        this.addMissingZeros();
    }

    /**
     * Appends a new kernel point that shares point 0's basis and has all-zero
     * coefficients.
     */
    public void addNewKernelPoint() {
        KernelPoint source = this.points.get(0);
        KernelPoint toAdd = new KernelPoint(this.k, this.errorTolerance);
        toAdd.setMaxBudget(this.maxBudget);
        toAdd.setBudgetStrategy(this.budgetStrategy);
        this.standardMove(toAdd, source);
        toAdd.kernelAccel = source.kernelAccel;
        toAdd.vecs = source.vecs;
        final int size = source.alpha.size();
        toAdd.alpha = new DoubleList(size);
        for (int i = 0; i < size; ++i) {
            toAdd.alpha.add(0.0);
        }
        this.points.add(toAdd);
    }

    /** Makes {@code destination} reference the {@code K}/{@code InvK} matrix fields of {@code source}. */
    private void standardMove(KernelPoint destination, KernelPoint source) {
        destination.InvK = source.InvK;
        destination.InvKExpanded = source.InvKExpanded;
        destination.K = source.K;
        destination.KExpanded = source.KExpanded;
    }

    /**
     * @return the number of basis vectors, as reported by point 0
     */
    public int getBasisSize() {
        return this.points.get(0).getBasisSize();
    }

    /**
     * @return a new list containing the shared basis vectors (the vectors themselves are not copied)
     */
    public List<Vec> getRawBasisVecs() {
        return new ArrayList<Vec>(this.points.get(0).vecs);
    }

    /**
     * @return the number of kernel points
     */
    public int size() {
        return this.points.size();
    }

    /**
     * @return a copy made with {@link #KernelPoints(KernelPoints)}
     */
    @Override
    public KernelPoints clone() {
        return new KernelPoints(this);
    }

    /** Pads every point's coefficient list with zeros up to the size of the shared basis. */
    private void addMissingZeros() {
        final int target = this.points.get(0).vecs.size();
        for (KernelPoint kp : this.points) {
            while (kp.alpha.size() < target) {
                kp.alpha.add(0.0);
            }
        }
    }
}

