Class CrossValidationGenerator

    • Field Detail

      • m_NumFolds

        protected int m_NumFolds
        the number of folds to use.
      • m_Randomization

        protected Randomization m_Randomization
        the randomization scheme.
      • m_Stratification

        protected Stratification m_Stratification
        the stratification scheme.
    • Constructor Detail

      • CrossValidationGenerator

        public CrossValidationGenerator()
        Initializes the cross-validation.
    • Method Detail

      • reset

        public void reset()
        Resets the scheme.
      • setNumFolds

        public void setNumFolds(int value)
        Sets the number of folds to use.
        Parameters:
        value - the number of folds; a value below 2 selects leave-one-out (LOO) cross-validation
      • getNumFolds

        public int getNumFolds()
        Returns the number of folds in use.
        Returns:
        the number of folds; a value below 2 means leave-one-out (LOO) cross-validation
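        The "LOO if <2" rule can be read as follows: leave-one-out uses one fold per data point, so the effective fold count equals the dataset size. The sketch below is only an illustration of that reading; the class FoldCountSketch and the method effectiveFolds are hypothetical names and not part of CrossValidationGenerator.

```java
/** Hypothetical helper, not part of the documented class. */
final class FoldCountSketch {

    /**
     * Maps a configured fold count to the number of folds that would be
     * used for a dataset of the given size, assuming "LOO if <2".
     */
    static int effectiveFolds(int numFolds, int dataSize) {
        // Values below 2 request leave-one-out: one fold per data point.
        return (numFolds < 2) ? dataSize : numFolds;
    }
}
```

        For example, a configured value of -1 on a dataset of 150 items would correspond to 150 folds under this assumption.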
      • setRandomization

        public void setRandomization(Randomization value)
        Sets the randomization scheme to use.
        Parameters:
        value - the scheme
      • getRandomization

        public Randomization getRandomization()
        Returns the randomization scheme in use.
        Returns:
        the scheme
      • setStratification

        public void setStratification(Stratification value)
        Sets the stratification scheme to use.
        Parameters:
        value - the scheme
      • getStratification

        public Stratification getStratification()
        Returns the stratification scheme in use.
        Returns:
        the scheme
      • generate

        public <T> List<FoldPair<Binnable<T>>> generate(List<Binnable<T>> data)
        Generates cross-validation fold pairs. Temporarily adds the original index to the Binnable meta-data, using Wrapping.TMP_INDEX as key.
        Type Parameters:
        T - the payload type
        Parameters:
        data - the data to generate the pairs from
        Returns:
        the fold pairs
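        The idea of temporarily storing the original index in the meta-data can be sketched without the library. The example below is an assumption-laden illustration, not the actual implementation: the Item class stands in for Binnable, the key value "tmp-index" stands in for Wrapping.TMP_INDEX (whose real value is not given here), and shuffledOriginalIndices is a hypothetical method name.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Hypothetical illustration of index tracking via temporary meta-data. */
final class IndexTrackingSketch {

    /** Stand-in key; the real value of Wrapping.TMP_INDEX is not known here. */
    static final String TMP_INDEX = "tmp-index";

    /** Minimal stand-in for a payload wrapper that carries meta-data. */
    static final class Item<T> {
        final T payload;
        final Map<String, Object> meta = new HashMap<>();

        Item(T payload) {
            this.payload = payload;
        }
    }

    /**
     * Tags each item with its position, shuffles a copy of the list, reads
     * the original positions back in shuffled order, then removes the tag.
     */
    static <T> List<Integer> shuffledOriginalIndices(List<Item<T>> data, long seed) {
        // Record the original position before any reordering happens.
        for (int i = 0; i < data.size(); i++) {
            data.get(i).meta.put(TMP_INDEX, i);
        }
        // Shuffle a copy so the caller's list keeps its order.
        List<Item<T>> copy = new ArrayList<>(data);
        Collections.shuffle(copy, new Random(seed));
        // The stored index survives the shuffle and identifies each item.
        List<Integer> result = new ArrayList<>();
        for (Item<T> item : copy) {
            result.add((Integer) item.meta.get(TMP_INDEX));
        }
        // The entry is temporary, so it is removed again afterwards.
        for (Item<T> item : data) {
            item.meta.remove(TMP_INDEX);
        }
        return result;
    }
}
```

        Storing the index on the item itself, rather than in a parallel list, keeps the mapping intact however the items are reordered or partitioned in between.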
      • trainCV

        public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data,
                                                    int numFolds,
                                                    int numFold)
        Creates the training set for one fold of a cross-validation on the dataset.
        Parameters:
        data - the dataset to create the training set from
        numFolds - the number of folds in the cross-validation. Must be greater than 1.
        numFold - 0 for the first fold, 1 for the second, ...
        Returns:
        the training set
        Throws:
        IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.
      • trainCV

        public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data,
                                                    int numFolds,
                                                    int numFold,
                                                    Randomization random)
        Creates the training set for one fold of a cross-validation on the dataset. The training set is then randomized using the given randomization scheme.
        Parameters:
        data - the dataset to create the training set from
        numFolds - the number of folds in the cross-validation. Must be greater than 1.
        numFold - 0 for the first fold, 1 for the second, ...
        random - the randomization scheme to apply to the training set
        Returns:
        the training set
        Throws:
        IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.
      • testCV

        public static <T> List<Binnable<T>> testCV(List<Binnable<T>> data,
                                                   int numFolds,
                                                   int numFold)
        Creates the test set for one fold of a cross-validation on the dataset.
        Parameters:
        data - the dataset to create the test set from
        numFolds - the number of folds in the cross-validation. Must be greater than 1.
        numFold - 0 for the first fold, 1 for the second, ...
        Returns:
        the test set
        Throws:
        IllegalArgumentException - if the number of folds is less than 2 or greater than the number of instances.
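        The relationship between trainCV and testCV can be sketched with plain lists. The example below assumes that each fold's test set is a contiguous block of the data and that the first (size % numFolds) folds receive one extra item; this page does not confirm that the library splits this way. FoldSplitSketch and testRange are hypothetical names, plain List<T> replaces List<Binnable<T>>, and the range check on numFold is an addition of this sketch.

```java
import java.util.ArrayList;
import java.util.List;

/** Hypothetical illustration of a contiguous train/test fold split. */
final class FoldSplitSketch {

    /** Returns {first, count} of the test block for the given fold. */
    static int[] testRange(int size, int numFolds, int numFold) {
        if (numFolds < 2) {
            throw new IllegalArgumentException("Number of folds must be at least 2");
        }
        if (numFolds > size) {
            throw new IllegalArgumentException("More folds than instances");
        }
        // Not stated in the documentation; added here to avoid bad ranges.
        if (numFold < 0 || numFold >= numFolds) {
            throw new IllegalArgumentException("Fold index out of range");
        }
        int base = size / numFolds;       // minimum items per fold
        int remainder = size % numFolds;  // folds that get one extra item
        int count = base + (numFold < remainder ? 1 : 0);
        int first = numFold * base + Math.min(numFold, remainder);
        return new int[]{first, count};
    }

    /** The test set is the block belonging to the given fold. */
    static <T> List<T> testCV(List<T> data, int numFolds, int numFold) {
        int[] r = testRange(data.size(), numFolds, numFold);
        return new ArrayList<>(data.subList(r[0], r[0] + r[1]));
    }

    /** The training set is everything outside that block. */
    static <T> List<T> trainCV(List<T> data, int numFolds, int numFold) {
        int[] r = testRange(data.size(), numFolds, numFold);
        List<T> train = new ArrayList<>(data.subList(0, r[0]));
        train.addAll(data.subList(r[0] + r[1], data.size()));
        return train;
    }
}
```

        Under these assumptions, 10 items split into 3 folds give test sets of 4, 3 and 3 items, and for every fold the training and test sets together cover the data exactly once.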