Class TabularRegression


  • public final class TabularRegression
    extends Object
    • Method Detail

      • createBlock

        public static ai.djl.nn.Block createBlock(ai.djl.zero.Performance performance,
                                                  int featureSize,
                                                  int labelSize)
      • train

        public static ai.djl.repository.zoo.ZooModel<ai.djl.basicdataset.tabular.ListFeatures, Float> train(
                ai.djl.basicdataset.tabular.TabularDataset dataset,
                ai.djl.zero.Performance performance)
                throws IOException, ai.djl.translate.TranslateException
        Trains a Model on a custom dataset. Currently, the model trained is a TabNet model.

        To train on a custom dataset, you must create a custom TabularDataset that loads your data.

        Parameters:
        dataset - the data to train with
        performance - the Performance setting that determines the desired model tradeoffs
        Returns:
        the model as a ZooModel
        Throws:
        IOException - if the dataset could not be loaded
        ai.djl.translate.TranslateException - if the translator has errors
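        Training on your own data requires a custom TabularDataset. As a rough illustration of the data such a dataset has to supply for regression, the sketch below is a hypothetical, pure-Java helper (CsvRegressionRows and Row are illustrative names, not part of DJL). It splits CSV text into a list of feature values kept as strings and a float label, which matches the Float output type of the returned ZooModel. It assumes a header row, plain comma separation with no quoted fields, and a numeric label column chosen by name. Wiring these rows into an actual TabularDataset subclass depends on DJL's dataset API and is not shown.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

/**
 * Hypothetical helper, not part of DJL: splits CSV text into string feature
 * lists and float labels, the shape of data a custom regression dataset
 * would need to supply.
 */
class CsvRegressionRows {

    /** One parsed row: feature values kept as strings plus a numeric label. */
    static final class Row {
        private final List<String> features;
        private final float label;

        Row(List<String> features, float label) {
            this.features = Collections.unmodifiableList(features);
            this.label = label;
        }

        List<String> features() {
            return features;
        }

        float label() {
            return label;
        }
    }

    /**
     * Parses CSV text whose first line is a header. The column named
     * labelColumn becomes the label; every other column becomes a feature.
     * Assumes plain comma separation with no quoted fields.
     */
    static List<Row> parse(String csv, String labelColumn) {
        String[] lines = csv.trim().split("\\R");
        List<String> header = new ArrayList<>();
        for (String name : lines[0].split(",", -1)) {
            header.add(name.trim());
        }
        int labelIndex = header.indexOf(labelColumn);
        if (labelIndex < 0) {
            throw new IllegalArgumentException("No column named " + labelColumn);
        }

        List<Row> rows = new ArrayList<>();
        for (int i = 1; i < lines.length; i++) {
            if (lines[i].trim().isEmpty()) {
                continue; // tolerate blank lines between records
            }
            String[] cells = lines[i].split(",", -1);
            if (cells.length != header.size()) {
                throw new IllegalArgumentException(
                        "Line " + (i + 1) + " has " + cells.length
                                + " columns, expected " + header.size());
            }
            List<String> features = new ArrayList<>();
            for (int c = 0; c < cells.length; c++) {
                if (c != labelIndex) {
                    features.add(cells[c].trim());
                }
            }
            // Float.parseFloat throws NumberFormatException for a non-numeric label.
            rows.add(new Row(features, Float.parseFloat(cells[labelIndex].trim())));
        }
        return rows;
    }

    public static void main(String[] args) {
        String csv = "size,rooms,price\n120,3,250.5\n80,2,180.0\n";
        for (Row row : parse(csv, "price")) {
            System.out.println(row.features() + " -> " + row.label());
        }
    }
}
```

        The helper fails fast on a missing label column or a row with the wrong number of cells, so malformed input surfaces before any training starts rather than partway through an epoch.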