Class CrossValidationGenerator
- java.lang.Object
  - adams.core.logging.LoggingObject
    - adams.core.logging.CustomLoggingLevelObject
      - adams.data.splitgenerator.generic.crossvalidation.CrossValidationGenerator
- All Implemented Interfaces:
LoggingLevelHandler, LoggingSupporter, SizeOfHandler, Serializable
public class CrossValidationGenerator extends CustomLoggingLevelObject
For generating cross-validation splits.
- Author:
- FracPete (fracpete at waikato dot ac dot nz)
- See Also:
- Serialized Form
-
Field Summary
Fields
Modifier and Type, Field, Description
protected int
m_NumFolds
the number of folds to use.
protected Randomization
m_Randomization
the randomization scheme.
protected Stratification
m_Stratification
the stratification scheme.
-
Fields inherited from class adams.core.logging.LoggingObject
m_Logger, m_LoggingIsEnabled, m_LoggingLevel
-
Constructor Summary
Constructors
Constructor, Description
CrossValidationGenerator()
Initializes the cross-validation.
-
Method Summary
Modifier and Type, Method, Description
<T> List<FoldPair<Binnable<T>>>
generate(List<Binnable<T>> data)
Generates cross-validation fold pairs.
int
getNumFolds()
Returns the number of folds in use.
Randomization
getRandomization()
Returns the randomization scheme in use.
Stratification
getStratification()
Returns the stratification scheme in use.
void
reset()
Resets the scheme.
void
setNumFolds(int value)
Sets the number of folds to use.
void
setRandomization(Randomization value)
Sets the randomization scheme to use.
void
setStratification(Stratification value)
Sets the stratification scheme to use.
static <T> List<Binnable<T>>
testCV(List<Binnable<T>> data, int numFolds, int numFold)
Creates the test set for one fold of a cross-validation on the dataset.
static <T> List<Binnable<T>>
trainCV(List<Binnable<T>> data, int numFolds, int numFold)
Creates the training set for one fold of a cross-validation on the dataset.
static <T> List<Binnable<T>>
trainCV(List<Binnable<T>> data, int numFolds, int numFold, Randomization random)
Creates the training set for one fold of a cross-validation on the dataset.
-
Methods inherited from class adams.core.logging.CustomLoggingLevelObject
setLoggingLevel
-
Methods inherited from class adams.core.logging.LoggingObject
configureLogger, getLogger, getLoggingLevel, initializeLogging, isLoggingEnabled, sizeOf
-
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
-
Methods inherited from interface adams.core.logging.LoggingLevelHandler
getLoggingLevel
-
Field Detail
-
m_NumFolds
protected int m_NumFolds
the number of folds to use.
-
m_Randomization
protected Randomization m_Randomization
the randomization scheme.
-
m_Stratification
protected Stratification m_Stratification
the stratification scheme.
-
Method Detail
-
reset
public void reset()
Resets the scheme.
-
setNumFolds
public void setNumFolds(int value)
Sets the number of folds to use.
- Parameters:
value
- the number of folds; leave-one-out (LOO) is used if less than 2
-
getNumFolds
public int getNumFolds()
Returns the number of folds in use.
- Returns:
- the number of folds; leave-one-out (LOO) is used if less than 2
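The "LOO if <2" convention can be illustrated with a minimal sketch. It assumes that leave-one-out means one fold per data item; the helper name `resolve` and the class `EffectiveFolds` are hypothetical and not part of this class.

```java
class EffectiveFolds {
  /**
   * Hypothetical helper: resolves the fold count that would actually be
   * used for a dataset with the given number of items.
   */
  static int resolve(int numFolds, int numInstances) {
    // Fewer than 2 folds requested: fall back to leave-one-out,
    // i.e. one fold per item (assumption about what LOO means here).
    return (numFolds < 2) ? numInstances : numFolds;
  }

  public static void main(String[] args) {
    System.out.println(resolve(10, 150)); // prints 10
    System.out.println(resolve(-1, 150)); // prints 150
  }
}
```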
-
setRandomization
public void setRandomization(Randomization value)
Sets the randomization scheme to use.
- Parameters:
value
- the scheme
-
getRandomization
public Randomization getRandomization()
Returns the randomization scheme in use.
- Returns:
- the scheme
-
setStratification
public void setStratification(Stratification value)
Sets the stratification scheme to use.
- Parameters:
value
- the scheme
-
getStratification
public Stratification getStratification()
Returns the stratification scheme in use.
- Returns:
- the scheme
-
generate
public <T> List<FoldPair<Binnable<T>>> generate(List<Binnable<T>> data)
Generates cross-validation fold pairs. Temporarily adds the original index to the Binnable meta-data, using Wrapping.TMP_INDEX as key.
- Type Parameters:
T
- the payload type
- Parameters:
data
- the data to generate the pairs from
- Returns:
- the fold pairs
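The temporary-index idea can be sketched without the ADAMS classes. In the sketch below, `Item` is a hypothetical stand-in for a binnable item with a meta-data map, the key string `"tmp-index"` is an assumed placeholder for the value of Wrapping.TMP_INDEX (not shown on this page), and the round-robin fold assignment is a simplification rather than the scheme this class uses.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

class IndexTrackingSketch {
  /** Hypothetical stand-in for a binnable item: a payload plus meta-data. */
  static final class Item<T> {
    final T payload;
    final Map<String, Object> meta = new HashMap<>();
    Item(T payload) { this.payload = payload; }
  }

  /** Assumed placeholder for the real key constant. */
  static final String TMP_INDEX = "tmp-index";

  /** Returns, per fold, the original indices of the items in its test set. */
  static <T> List<List<Integer>> testIndices(List<Item<T>> data, int numFolds, long seed) {
    // 1. Remember each item's original position in its meta-data.
    for (int i = 0; i < data.size(); i++)
      data.get(i).meta.put(TMP_INDEX, i);
    // 2. Shuffle a copy so the caller's list order stays untouched.
    List<Item<T>> shuffled = new ArrayList<>(data);
    Collections.shuffle(shuffled, new Random(seed));
    // 3. Deal the shuffled items round-robin into folds and read back
    //    the original index of each item.
    List<List<Integer>> result = new ArrayList<>();
    for (int f = 0; f < numFolds; f++)
      result.add(new ArrayList<>());
    for (int i = 0; i < shuffled.size(); i++)
      result.get(i % numFolds).add((Integer) shuffled.get(i).meta.get(TMP_INDEX));
    // 4. Remove the temporary key again, since it was only added temporarily.
    for (Item<T> item : data)
      item.meta.remove(TMP_INDEX);
    return result;
  }
}
```

Storing the index on the item itself means the original positions survive any reordering, which is what makes it possible to report fold membership in terms of the input order.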
-
trainCV
public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data, int numFolds, int numFold)
Creates the training set for one fold of a cross-validation on the dataset.
- Parameters:
data
- the dataset to split
numFolds
- the number of folds in the cross-validation. Must be greater than 1.
numFold
- 0 for the first fold, 1 for the second, ...
- Returns:
- the training set
- Throws:
IllegalArgumentException
- if the number of folds is less than 2 or greater than the number of instances.
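A minimal sketch of how a training set for one fold can be cut from a list follows. It assumes contiguous test blocks whose sizes differ by at most one item (the fold arithmetic used by Weka's `Instances.trainCV`); whether this class uses exactly the same arithmetic is not stated on this page, and `TrainFoldSketch` and `trainFold` are hypothetical names working on plain lists rather than Binnable items.

```java
import java.util.ArrayList;
import java.util.List;

class TrainFoldSketch {
  /** Training set for fold numFold (0-based): everything outside the test block. */
  static <T> List<T> trainFold(List<T> data, int numFolds, int numFold) {
    int n = data.size();
    if (numFolds < 2)
      throw new IllegalArgumentException("Number of folds must be at least 2!");
    if (numFolds > n)
      throw new IllegalArgumentException("Can't have more folds than instances!");
    int size = n / numFolds;   // base size of a test block
    int extra = n % numFolds;  // the first 'extra' folds hold one more item
    // Start of this fold's test block, shifted by the larger folds before it.
    int first = numFold * size + Math.min(numFold, extra);
    if (numFold < extra)
      size++;
    // Training data = items before the test block + items after it.
    List<T> train = new ArrayList<>(data.subList(0, first));
    train.addAll(data.subList(first + size, n));
    return train;
  }
}
```

With 10 items and 3 folds, the test blocks under this assumption have sizes 4, 3 and 3, so fold 1 trains on items 0 to 3 and 7 to 9.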
-
trainCV
public static <T> List<Binnable<T>> trainCV(List<Binnable<T>> data, int numFolds, int numFold, Randomization random)
Creates the training set for one fold of a cross-validation on the dataset. The training data is subsequently randomized using the given randomization scheme.
- Parameters:
data
- the dataset to split
numFolds
- the number of folds in the cross-validation. Must be greater than 1.
numFold
- 0 for the first fold, 1 for the second, ...
random
- the randomization scheme used to shuffle the training set
- Returns:
- the training set
- Throws:
IllegalArgumentException
- if the number of folds is less than 2 or greater than the number of instances.
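The order of operations matters here: the fold is cut first and only the resulting training set is shuffled, so fold membership does not depend on the randomization. The sketch below shows just the shuffling step, with a seeded `java.util.Random` standing in for the Randomization scheme, whose API is not shown on this page; `ShuffledTrainSketch` and `shuffledCopy` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

class ShuffledTrainSketch {
  /** Returns a shuffled copy of an already-extracted training set. */
  static <T> List<T> shuffledCopy(List<T> train, long seed) {
    // Copy first so the input list is left unchanged.
    List<T> result = new ArrayList<>(train);
    // A fixed seed makes the shuffle reproducible across runs.
    Collections.shuffle(result, new Random(seed));
    return result;
  }
}
```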
-
testCV
public static <T> List<Binnable<T>> testCV(List<Binnable<T>> data, int numFolds, int numFold)
Creates the test set for one fold of a cross-validation on the dataset.
- Parameters:
data
- the dataset to split
numFolds
- the number of folds in the cross-validation. Must be greater than 1.
numFold
- 0 for the first fold, 1 for the second, ...
- Returns:
- the test set
- Throws:
IllegalArgumentException
- if the number of folds is less than 2 or greater than the number of instances.
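The test set is the complement of the training set for the same fold. The sketch below assumes contiguous test blocks whose sizes differ by at most one item (as in Weka's `Instances.testCV`), which may not match this class exactly; `TestFoldSketch` and `testFold` are hypothetical names working on plain lists.

```java
import java.util.ArrayList;
import java.util.List;

class TestFoldSketch {
  /** Test set for fold numFold (0-based): one contiguous block of the data. */
  static <T> List<T> testFold(List<T> data, int numFolds, int numFold) {
    int n = data.size();
    if (numFolds < 2)
      throw new IllegalArgumentException("Number of folds must be at least 2!");
    if (numFolds > n)
      throw new IllegalArgumentException("Can't have more folds than instances!");
    int size = n / numFolds;   // base size of a test block
    int extra = n % numFolds;  // the first 'extra' folds hold one more item
    int first = numFold * size + Math.min(numFold, extra);
    if (numFold < extra)
      size++;
    return new ArrayList<>(data.subList(first, first + size));
  }

  public static void main(String[] args) {
    List<Integer> data = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
    for (int fold = 0; fold < 3; fold++)
      System.out.println(testFold(data, 3, fold));
  }
}
```

Taken over all folds, the test blocks cover every item exactly once, which is the property a cross-validation relies on.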
-