Package adams.data.tabnet
Class TabularRegression
- java.lang.Object
  - adams.data.tabnet.TabularRegression
public final class TabularRegression extends Object
Method Summary
Modifier and Type | Method | Description
static ai.djl.nn.Block | createBlock(ai.djl.zero.Performance performance, int featureSize, int labelSize) |
static ai.djl.repository.zoo.ZooModel<ai.djl.basicdataset.tabular.ListFeatures,Float> | train(ai.djl.basicdataset.tabular.TabularDataset dataset, ai.djl.zero.Performance performance) | Trains a Model on a custom dataset.
Method Detail
createBlock
public static ai.djl.nn.Block createBlock(ai.djl.zero.Performance performance, int featureSize, int labelSize)
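No description is given for `createBlock`. The sketch below shows one plausible way to call it and attach the returned block to a plain DJL `Model`. It assumes DJL (including the `zero` module and an engine such as PyTorch) and the ADAMS `adams.data.tabnet` package are on the classpath; the sizes and the model name `"tabular-regression"` are illustrative choices, not values taken from this class.

```java
import adams.data.tabnet.TabularRegression;
import ai.djl.Model;
import ai.djl.nn.Block;
import ai.djl.zero.Performance;

public class CreateBlockExample {
  public static void main(String[] args) {
    // Illustrative sizes (assumption): 8 numeric input columns, 1 regression target.
    int featureSize = 8;
    int labelSize = 1;

    // Ask the helper for a network sized for the data and the desired tradeoff.
    Block block = TabularRegression.createBlock(Performance.FAST, featureSize, labelSize);

    // Attach the (not yet initialized) block to a new model; the name is arbitrary.
    try (Model model = Model.newInstance("tabular-regression")) {
      model.setBlock(block);
      System.out.println("block attached: " + (model.getBlock() == block));
    }
  }
}
```

The block still has to be initialized (for example by a `Trainer`) before it can run a forward pass; this sketch only shows the wiring.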
train
public static ai.djl.repository.zoo.ZooModel<ai.djl.basicdataset.tabular.ListFeatures,Float> train(ai.djl.basicdataset.tabular.TabularDataset dataset, ai.djl.zero.Performance performance) throws IOException, ai.djl.translate.TranslateException
Trains a Model on a custom dataset. Currently, this trains a TabNet Model.

To train on a custom dataset, you must create a custom TabularDataset to load your data.

Parameters:
- dataset - the data to train with
- performance - determines the desired model tradeoffs
Returns:
- the model as a ZooModel
Throws:
- IOException - if the dataset could not be loaded
- ai.djl.translate.TranslateException - if the translator has errors
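The description above says a custom TabularDataset is required. A minimal sketch, assuming DJL's `CsvDataset` (a `TabularDataset` subclass from `basicdataset`) is used to load a hypothetical `data.csv` with a header row and numeric columns `x1`, `x2` and `y`; the file name, column names, batch size and input values are all illustrative assumptions. DJL (with `basicdataset`, `zero` and an engine), Apache Commons CSV 1.9 or later, and ADAMS must be on the classpath.

```java
import adams.data.tabnet.TabularRegression;
import ai.djl.basicdataset.tabular.CsvDataset;
import ai.djl.basicdataset.tabular.ListFeatures;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.zero.Performance;
import java.nio.file.Paths;
import org.apache.commons.csv.CSVFormat;

public class TrainExample {
  public static void main(String[] args) throws Exception {
    // Hypothetical CSV: the first row is parsed as the header (column names).
    CsvDataset dataset = CsvDataset.builder()
        .optCsvFile(Paths.get("data.csv"))
        .setCsvFormat(CSVFormat.DEFAULT.builder().setHeader().build())
        .setSampling(32, true)        // batch size 32, shuffled
        .addNumericFeature("x1")      // hypothetical input columns
        .addNumericFeature("x2")
        .addNumericLabel("y")         // hypothetical regression target
        .build();

    // Train, then run a single prediction with the returned model's translator.
    try (ZooModel<ListFeatures, Float> model = TabularRegression.train(dataset, Performance.FAST);
         Predictor<ListFeatures, Float> predictor = model.newPredictor()) {
      // Feature values are passed as strings, in the same order as the features above.
      ListFeatures input = new ListFeatures();
      input.add("0.5");
      input.add("1.2");
      Float prediction = predictor.predict(input);
      System.out.println("prediction: " + prediction);
    }
  }
}
```

Both the returned `ZooModel` and the `Predictor` hold native resources, which is why the sketch opens them in a try-with-resources block.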