train.model {SIAMCAT}    R Documentation

Model training

Description

This function trains a machine learning model on the training data.

Usage

train.model(siamcat,
method = c("lasso", "enet", "ridge", "lasso_ll", "ridge_ll", "randomForest"),
stratify = TRUE, modsel.crit = list("auc"), min.nonzero.coeff = 1,
param.set = NULL, verbose = 1)

Arguments

siamcat

object of class siamcat-class

method

string, specifies the type of model to be trained; must be one of c('lasso', 'enet', 'ridge', 'lasso_ll', 'ridge_ll', 'randomForest')

stratify

boolean, whether the folds in the internal cross-validation should be stratified; defaults to TRUE

modsel.crit

list, specifies the model selection criterion used during internal cross-validation; may contain any of c('auc', 'f1', 'acc', 'pr'); defaults to list('auc')

min.nonzero.coeff

integer, minimum number of nonzero coefficients that should be present in the model (only for 'lasso', 'ridge', and 'enet'); defaults to 1

param.set

a list of extra parameters for the mlr run; may contain:

  • cost for lasso_ll and ridge_ll

  • alpha for enet

  • ntree and mtry for randomForest

Defaults to NULL

verbose

control output: 0 for no output at all, 1 for only information about progress and success, 2 for a normal level of information, and 3 for full debug information; defaults to 1
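With stratify = TRUE, the folds of the internal cross-validation are stratified, meaning each fold keeps roughly the class proportions of the full data set. The base-R sketch below only illustrates that idea: assign_stratified_folds is a hypothetical helper (not a SIAMCAT or mlr function), and the labels are made up.

```r
# Hypothetical helper (not part of SIAMCAT): assign samples to k folds so that
# each fold keeps roughly the same class proportions as the full label vector.
assign_stratified_folds <- function(labels, k) {
  folds <- integer(length(labels))
  for (cls in unique(labels)) {
    idx <- which(labels == cls)
    # shuffle the samples of this class, then deal them out round-robin
    idx <- idx[sample.int(length(idx))]
    folds[idx] <- rep_len(seq_len(k), length(idx))
  }
  folds
}

set.seed(1)
labels <- c(rep(-1, 20), rep(1, 10))   # made-up labels: 20 negatives, 10 positives
folds <- assign_stratified_folds(labels, k = 5)
table(folds, labels)
```

Here table(folds, labels) shows 4 negative and 2 positive samples in every fold, matching the 2:1 ratio of the full data.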

Details

This function performs the training of the machine learning model and serves as an interface to the mlr package.

The function expects a siamcat-class object with a prepared cross-validation split (see create.data.split) in its data_split slot. It then trains one model for each fold of the data split.
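The per-fold training loop can be sketched in base R as follows. This is an illustration under stated assumptions, not the SIAMCAT implementation: the data are simulated, folds are assigned with a simple round-robin scheme, and an unpenalised logistic regression (glm) stands in for the mlr Learners that train.model actually uses.

```r
set.seed(42)
n <- 60
# simulated feature matrix and binary labels driven by the first feature
x <- matrix(rnorm(n * 3), ncol = 3, dimnames = list(NULL, paste0("feat", 1:3)))
y <- as.integer(x[, 1] + rnorm(n, sd = 0.5) > 0)
folds <- rep_len(1:3, n)   # simple (unstratified) fold assignment, for illustration

# train one model per fold, each on all samples outside that fold
models <- lapply(sort(unique(folds)), function(f) {
  train <- folds != f
  df <- data.frame(y = y[train], x[train, , drop = FALSE])
  glm(y ~ ., data = df, family = binomial())
})
length(models)  # one model per fold
```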

For machine learning methods that require additional hyperparameters (e.g. lasso_ll), the hyperparameters are tuned with the function tuneParams from the mlr package.
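The general idea of hyperparameter tuning can be sketched as a grid search: each candidate value is scored by its mean held-out AUC (the default modsel.crit) over inner cross-validation folds, and the best value is kept. This does not reproduce tuneParams. auc, fit_ridge and tune_lambda are hypothetical helpers, closed-form ridge regression on 0/1 labels is a swapped-in stand-in for the actual Learners, and the data are simulated.

```r
# AUC from the Mann-Whitney rank statistic (labels: 1 = positive, 0 = negative)
auc <- function(scores, labels) {
  pos <- labels == 1
  r <- rank(scores)  # average ranks handle ties
  (sum(r[pos]) - sum(pos) * (sum(pos) + 1) / 2) / (sum(pos) * sum(!pos))
}

# Closed-form ridge regression on the 0/1 labels (stand-in for a real classifier)
fit_ridge <- function(x, y, lambda) {
  x1 <- cbind(1, x)
  penalty <- diag(c(0, rep(lambda, ncol(x))))  # intercept is not penalised
  solve(crossprod(x1) + penalty, crossprod(x1, y))
}

# Grid search: score every lambda by mean held-out AUC over k inner folds
tune_lambda <- function(x, y, grid, k = 3) {
  folds <- rep_len(seq_len(k), nrow(x))
  cv_auc <- sapply(grid, function(lambda) {
    mean(sapply(seq_len(k), function(f) {
      held_out <- folds == f
      beta <- fit_ridge(x[!held_out, , drop = FALSE], y[!held_out], lambda)
      scores <- cbind(1, x[held_out, , drop = FALSE]) %*% beta
      auc(as.vector(scores), y[held_out])
    }))
  })
  list(best = grid[which.max(cv_auc)], cv_auc = cv_auc)
}

set.seed(7)
x <- matrix(rnorm(90 * 5), ncol = 5)
y <- as.integer(x[, 1] - x[, 2] + rnorm(90) > 0)
res <- tune_lambda(x, y, grid = c(0.01, 0.1, 1, 10, 100))
res$best
```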

The methods 'lasso', 'enet', and 'ridge' are implemented as mlr tasks using the 'classif.cvglmnet' Learner; 'lasso_ll' and 'ridge_ll' use the 'classif.LiblineaRL1LogReg' and 'classif.LiblineaRL2LogReg' Learners, respectively. The 'randomForest' method is implemented via the 'classif.randomForest' Learner.
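The method-to-Learner mapping described in the previous paragraph can be written as a small lookup table. mlr_learner_for is a hypothetical helper for illustration only; it is not part of SIAMCAT.

```r
# Hypothetical helper (not a SIAMCAT function): map a train.model method
# name to the mlr Learner named in the documentation.
mlr_learner_for <- function(method) {
  learners <- c(
    lasso        = "classif.cvglmnet",
    enet         = "classif.cvglmnet",
    ridge        = "classif.cvglmnet",
    lasso_ll     = "classif.LiblineaRL1LogReg",
    ridge_ll     = "classif.LiblineaRL2LogReg",
    randomForest = "classif.randomForest"
  )
  method <- match.arg(method, names(learners))  # errors on unknown methods
  unname(learners[[method]])
}

mlr_learner_for("ridge_ll")
```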

Value

object of class siamcat-class with added model_list

Examples


    library(SIAMCAT)
    data(siamcat_example)
    # simple working example
    siamcat_trained <- train.model(siamcat_example, method='lasso')


[Package SIAMCAT version 1.0.0 Index]