Package adams.data.tabnet
Class TabularRegression
- java.lang.Object
  - adams.data.tabnet.TabularRegression
public final class TabularRegression extends Object
Method Summary
Modifier and Type | Method | Description
static ai.djl.nn.Block | createBlock(ai.djl.zero.Performance performance, int featureSize, int labelSize) |
static ai.djl.repository.zoo.ZooModel<ai.djl.basicdataset.tabular.ListFeatures,Float> | train(ai.djl.basicdataset.tabular.TabularDataset dataset, ai.djl.zero.Performance performance) | Trains a Model on a custom dataset.
Method Detail
createBlock
public static ai.djl.nn.Block createBlock(ai.djl.zero.Performance performance, int featureSize, int labelSize)
train
public static ai.djl.repository.zoo.ZooModel<ai.djl.basicdataset.tabular.ListFeatures,Float> train(ai.djl.basicdataset.tabular.TabularDataset dataset, ai.djl.zero.Performance performance) throws IOException, ai.djl.translate.TranslateException
Trains a Model on a custom dataset. Currently, trains a TabNet Model.
In order to train on a custom dataset, you must create a custom TabularDataset to load your data.
Parameters:
dataset - the data to train with
performance - to determine the desired model tradeoffs
Returns:
the model as a ZooModel
Throws:
IOException - if the dataset could not be loaded
ai.djl.translate.TranslateException - if the translator has errors
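The call pattern implied by the signature above (two checked exceptions, and a returned model that should be closed after use) might look roughly like the sketch below. This is an illustration under stated assumptions, not code from ADAMS or DJL. So that the example compiles without either library on the classpath, it declares minimal hypothetical stand-ins for Performance, TabularDataset, ZooModel, TranslateException and train; the stand-in ZooModel is non-generic and the stand-in train does no real training. In real code you would import the actual classes and call TabularRegression.train directly.

```java
import java.io.IOException;

// Sketch only: every nested type here is a hypothetical stand-in, not the real
// DJL/ADAMS class of the same name. Replace them with the real imports in practice.
class TrainCallSketch {

    // Stand-in for ai.djl.zero.Performance.
    enum Performance { FAST, BALANCED, ACCURATE }

    // Stand-in for ai.djl.translate.TranslateException (a checked exception).
    static class TranslateException extends Exception {
        TranslateException(String message) { super(message); }
    }

    // Stand-in for TabularDataset; 'loadable' simulates whether the data can be read.
    static class TabularDataset {
        final boolean loadable;
        TabularDataset(boolean loadable) { this.loadable = loadable; }
    }

    // Stand-in for ZooModel (simplified to a non-generic, closeable resource).
    static class ZooModel implements AutoCloseable {
        boolean closed;
        @Override public void close() { closed = true; }
    }

    // Stand-in with the same parameter and exception shape as the documented train method.
    static ZooModel train(TabularDataset dataset, Performance performance)
            throws IOException, TranslateException {
        if (!dataset.loadable) {
            throw new IOException("dataset could not be loaded");
        }
        return new ZooModel();
    }

    // The call pattern itself: try-with-resources closes the model, and each
    // documented exception is handled separately.
    static String trainAndReport(TabularDataset dataset, Performance performance) {
        try (ZooModel model = train(dataset, performance)) {
            return "trained with " + performance;
        } catch (IOException e) {
            return "load failed: " + e.getMessage();
        } catch (TranslateException e) {
            return "translation failed: " + e.getMessage();
        }
    }
}
```

Handling IOException and TranslateException in separate catch blocks keeps data-loading problems distinguishable from translator problems, which matches how the Throws section describes them.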