/*
 * Decompiled with CFR 0.152.
 */
package jsat.utils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import jsat.SimpleDataSet;
import jsat.classifiers.CategoricalData;
import jsat.classifiers.DataPoint;
import jsat.distributions.ContinuousDistribution;
import jsat.distributions.Uniform;
import jsat.linear.DenseVector;
import jsat.utils.random.RandomUtil;

/**
 * Generates a synthetic classification data set whose classes sit on the
 * integer points of a regular grid.
 * <p>
 * For dimension sizes {@code d_0, ..., d_{k-1}} there are
 * {@code d_0 * ... * d_{k-1}} grid cells, and every cell is its own class. A
 * sample from the cell with coordinates {@code (c_0, ..., c_{k-1})} is the
 * point {@code (c_0 + e_0, ..., c_{k-1} + e_{k-1})}, where each {@code e_j}
 * is an independent noise draw computed as
 * {@code noiseSource.invCdf(rand.nextDouble())}.
 * <p>
 * Class labels are assigned in row-major order with the last dimension
 * varying fastest: cell {@code (0, ..., 0, 0)} is class 0,
 * {@code (0, ..., 0, 1)} is class 1, and so on. Each call to
 * {@link #generateData(int)} draws fresh values from the supplied
 * {@link Random}.
 */
public class GridDataGenerator {
    /** Distribution the per-coordinate noise is drawn from. */
    private final ContinuousDistribution noiseSource;
    /** Number of grid points along each axis (a private copy of the constructor argument). */
    private final int[] dimensions;
    /** Source of the uniform values passed to {@code noiseSource.invCdf}. */
    private final Random rand;

    /**
     * Creates a new grid data generator.
     *
     * @param noiseSource the distribution from which the noise added to each
     *                    coordinate of each sample is drawn
     * @param rand        the source of randomness used to sample the noise
     * @param dimensions  the number of grid points along each axis; at least
     *                    one value is required and every value must be
     *                    positive. The array is copied.
     * @throws IllegalArgumentException if {@code dimensions} is {@code null} or empty
     * @throws ArithmeticException      if any dimension is non-positive
     */
    public GridDataGenerator(ContinuousDistribution noiseSource, Random rand, int... dimensions) {
        if (dimensions == null || dimensions.length == 0) {
            // previously this only surfaced later as an ArrayIndexOutOfBoundsException in generateData
            throw new IllegalArgumentException("At least one dimension must be specified");
        }
        for (int i = 0; i < dimensions.length; i++) {
            if (dimensions[i] <= 0) {
                // ArithmeticException is kept for compatibility with existing callers
                throw new ArithmeticException("The " + i + "'th dimension contains the non positive value " + dimensions[i]);
            }
        }
        this.noiseSource = noiseSource;
        this.rand = rand;
        // defensive copy: the (varargs) array belongs to the caller and could be mutated after validation
        this.dimensions = dimensions.clone();
    }

    /**
     * Creates a new grid data generator that uses the random source returned
     * by {@code RandomUtil.getRandom()}.
     *
     * @param noiseSource the distribution from which the noise added to each
     *                    coordinate of each sample is drawn
     * @param dimensions  the number of grid points along each axis; at least
     *                    one value is required and every value must be positive
     * @throws IllegalArgumentException if {@code dimensions} is {@code null} or empty
     * @throws ArithmeticException      if any dimension is non-positive
     */
    public GridDataGenerator(ContinuousDistribution noiseSource, int... dimensions) {
        this(noiseSource, RandomUtil.getRandom(), dimensions);
    }

    /**
     * Creates a 2 x 5 grid generator with noise drawn from
     * {@code new Uniform(-0.25, 0.25)} and the random source returned by
     * {@code RandomUtil.getRandom()}.
     */
    public GridDataGenerator() {
        this(new Uniform(-0.25, 0.25), RandomUtil.getRandom(), 2, 5);
    }

    /**
     * Generates a new data set containing {@code samples} noisy points for
     * every grid cell.
     *
     * @param samples the number of points to generate per grid cell
     * @return a data set whose points carry a single categorical variable
     *         (the cell index) with as many categories as there are grid cells
     * @throws IllegalArgumentException if {@code samples} is negative
     * @throws ArithmeticException      if the number of cells, or of cells times
     *                                  {@code samples}, overflows an {@code int}
     */
    public SimpleDataSet generateData(int samples) {
        if (samples < 0) {
            throw new IllegalArgumentException("samples must be non-negative, not " + samples);
        }
        int totalClasses = 1;
        for (int d : dimensions) {
            // exact multiply: a silently wrapped product would give a wrong category count
            totalClasses = Math.multiplyExact(totalClasses, d);
        }
        // one categorical variable, shared by every generated point: the cell (class) index
        CategoricalData[] catDataInfo = { new CategoricalData(totalClasses) };
        List<DataPoint> dataPoints = new ArrayList<>(Math.multiplyExact(totalClasses, samples));

        // current grid cell, advanced like an odometer in row-major order (last axis fastest)
        int[] cell = new int[dimensions.length];
        for (int classIndex = 0; classIndex < totalClasses; classIndex++) {
            for (int s = 0; s < samples; s++) {
                DenseVector point = new DenseVector(cell.length);
                for (int j = 0; j < cell.length; j++) {
                    point.set(j, cell[j] + noiseSource.invCdf(rand.nextDouble()));
                }
                dataPoints.add(new DataPoint(point, new int[]{classIndex}, catDataInfo));
            }
            advanceCell(cell);
        }
        return new SimpleDataSet(dataPoints);
    }

    /**
     * Advances {@code cell} to the next grid cell in row-major order (last
     * dimension fastest). After the final cell, every coordinate wraps back to 0.
     *
     * @param cell the current cell coordinates; updated in place
     */
    private void advanceCell(int[] cell) {
        for (int j = cell.length - 1; j >= 0; j--) {
            if (++cell[j] < dimensions[j]) {
                return;
            }
            cell[j] = 0; // carry into the next slower-varying dimension
        }
    }
}

