Train word embeddings to predict a categorical variable using a random forest.

textTrainRandomForest(
  x,
  y,
  x_append = NULL,
  append_first = FALSE,
  cv_method = "validation_split",
  outside_folds = 10,
  inside_folds = 3/4,
  strata = "y",
  outside_strata = TRUE,
  outside_breaks = 4,
  inside_strata = TRUE,
  inside_breaks = 4,
  mode_rf = "classification",
  preprocess_step_center = FALSE,
  preprocess_scale_center = FALSE,
  preprocess_PCA = NA,
  extremely_randomised_splitrule = "extratrees",
  mtry = c(1, 10, 20, 40),
  min_n = c(1, 10, 20, 40),
  trees = c(1000),
  eval_measure = "bal_accuracy",
  model_description = "Consider writing a description of your model here",
  multi_cores = "multi_cores_sys_default",
  save_output = "all",
  simulate.p.value = FALSE,
  seed = 2020,
  ...
)

Arguments

x

Word embeddings from textEmbed.

y

Categorical variable to predict.

x_append

(optional) Variables to be appended after the word embeddings (x); to prepend them before the word embeddings, use append_first = TRUE. If you do not want to train with word embeddings, set x = NULL (default = NULL).

append_first

(boolean) If TRUE, the x_append variables are placed before the word embeddings; if FALSE (default), they are appended after them. A minimal sketch follows below.
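For example, additional numeric predictors can be supplied alongside the word embeddings. A minimal sketch (not run), assuming the example data contain an age column; extra_predictors is an illustrative name:

# Hypothetical additional predictors, one row per observation
extra_predictors <- tibble::tibble(
  age = Language_based_assessment_data_8$age
)

trained_with_append <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  x_append = extra_predictors, # appended after the embeddings by default
  append_first = TRUE,         # TRUE places them before the embeddings instead
  multi_cores = FALSE
)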

cv_method

(character) Cross-validation method to use within a pipeline of nested outer and inner loops of folds (see nested_cv in rsample). The default, "validation_split", uses cross-validation folds in the outer loop and rsample::validation_split in the inner loop to create a development and an assessment set (note that with "validation_split", inside_folds should be a proportion, e.g., inside_folds = 3/4). "cv_folds" instead uses rsample::vfold_cv to create n folds in both the outer and inner loops. See the sketch below.
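A sketch (not run) of the two cross-validation setups; only the cross-validation arguments differ between the calls:

# Default: rsample::validation_split in the inner loop (inside_folds is a proportion)
m_validation <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  cv_method = "validation_split",
  outside_folds = 10,
  inside_folds = 3/4,
  multi_cores = FALSE
)

# Alternative: rsample::vfold_cv in both loops (inside_folds is a number of folds)
m_folds <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  cv_method = "cv_folds",
  outside_folds = 10,
  inside_folds = 3,
  multi_cores = FALSE
)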

outside_folds

(numeric) Number of folds for the outer folds (default = 10).

inside_folds

(numeric) Number of folds for the inner folds, or, when cv_method = "validation_split", the proportion used for the development set (default = 3/4).

strata

(string or tibble; default = "y") Variable to stratify by. If a string, the variable must be in the training set; to stratify by another variable, supply it as a single-column tibble (note that only one stratification variable can be used). Set to NULL to disable stratification. See the sketch below.
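A sketch (not run) of stratifying by a variable other than y, passed as a single-column tibble (assuming the example data contain an age column):

trained_strata <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  strata = tibble::tibble(age = Language_based_assessment_data_8$age),
  outside_strata = TRUE,
  inside_strata = TRUE,
  multi_cores = FALSE
)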

outside_strata

(boolean) Whether to stratify the outside folds.

outside_breaks

(numeric) Number of bins used to stratify a numeric stratification variable in the outer cross-validation loop (default = 4).

inside_strata

(boolean) Whether to stratify the inside folds.

inside_breaks

(numeric) Number of bins used to stratify a numeric stratification variable in the inner cross-validation loop (default = 4).

mode_rf

Default is "classification" ("regression" is not supported yet).

preprocess_step_center

(boolean) Centers dimensions to have a mean of zero (default = FALSE). For more information see step_center in recipes.

preprocess_scale_center

(boolean) Scales dimensions to have a standard deviation of one (default = FALSE). For more information see step_scale in recipes.

preprocess_PCA

Pre-processing threshold for PCA. You can select the amount of variance to retain (e.g., .90, or a grid such as c(0.80, 0.90)) or the number of components to keep (e.g., 10). To skip this step, set preprocess_PCA to NA. The default, "min_halving", selects the number of PCA components based on the number of participants and features (word embedding dimensions) in the data using the formula: preprocess_PCA = round(max(min(number_features/2, number_participants/2), min(50, number_features))). A worked example follows below.
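A worked example (with illustrative numbers, not package code) of the "min_halving" formula above, for 768 embedding dimensions and 40 participants:

number_features <- 768
number_participants <- 40
round(max(
  min(number_features / 2, number_participants / 2), # the smaller of the two halves (20)
  min(50, number_features)                           # at least min(50, number of features) (50)
))
#> [1] 50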

extremely_randomised_splitrule

Default is "extratrees", which implements extremely randomized trees; you can also select NULL, "gini", or "hellinger", in which case your mtry settings will be overridden (see Geurts et al. (2006), Extremely randomized trees, for details; and see the ranger R package for details on the implementations).

mtry

Hyperparameter that may be tuned (default = c(1, 10, 20, 40)).

min_n

Hyperparameter that may be tuned (default = c(1, 10, 20, 40)).

trees

Number of trees to use (default 1000).

eval_measure

(character) Measure used to evaluate the models when selecting the best hyperparameters (default = "bal_accuracy"); see also "accuracy", "roc_auc", "sens", "spec", "precision", "kappa", and "f_measure".

model_description

(character) Text to describe your model (optional; good when sharing the model with others).

multi_cores

If TRUE, it enables the use of multiple cores when the computer system allows for it (i.e., only on Unix, not Windows), which makes the analyses considerably faster to run. The default, "multi_cores_sys_default", automatically uses TRUE on Mac and Linux and FALSE on Windows. Note that setting it to TRUE does not currently allow for reproducible results (i.e., the seed cannot be set).

save_output

(character) Which output to save (default = "all"); see also "only_results" and "only_results_predictions".

simulate.p.value

(boolean) Passed to fisher.test: a logical indicating whether to compute p-values by Monte Carlo simulation in tables larger than 2 × 2.

seed

(numeric) Set a different seed (default = 2020).

...

Additional settings, for example to set event_level in yardstick::accuracy (e.g., event_level = "second"); see the sketch below.
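A sketch (not run) of passing event_level through ... so that yardstick measures treat the second factor level as the event of interest:

trained_second_level <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  event_level = "second",
  multi_cores = FALSE
)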

Value

A list with roc_curve_data, roc_curve_plot, truth and predictions, preprocessing_recipe, final_model, model_description, chi-squared and Fisher's test results, as well as evaluation measures, e.g., accuracy, f_meas, and roc_auc (for details on these measures, see the yardstick R package documentation).
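A sketch of inspecting parts of the returned list; element names follow the description above:

trained_model$results          # evaluation measures (e.g., accuracy, roc_auc)
trained_model$roc_curve_plot   # ROC curve plot
trained_model$final_model      # the final fitted model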

Examples

# Examines how well the embeddings from the column "harmonywords" in
# Language_based_assessment_data_8 can classify the binary variable gender.

if (FALSE) {
trained_model <- textTrainRandomForest(
  x = word_embeddings_4$texts$harmonywords,
  y = as.factor(Language_based_assessment_data_8$gender),
  trees = c(1000, 1500),
  mtry = c(1), # this is short because of testing
  min_n = c(1), # this is short because of testing
  multi_cores = FALSE # This is FALSE due to CRAN testing and Windows machines.
)


# Examine results (e.g., evaluation measures such as accuracy, bal_accuracy,
# and roc_auc, together with the chi-squared and Fisher's test results).

trained_model$results
}