How to Create a Parametric Survival Learner for mlr in R

I am following the instructions (https://mlr.mlr-org.com/articles/tutorial/create_learner.html) to create a parametric survival learner to use with mlr. My code is below.

When I call makeLearner(id = "AFT", "surv.parametric"), I get an error saying that dist is missing and has no default. My code already specifies "weibull" as the default for dist.

makeRLearner.surv.parametric = function() {
  makeRLearnerSurv(
    cl = "surv.parametric",
    package = "survival",
    par.set = makeParamSet(
      makeDiscreteLearnerParam(id = "dist", default = "weibull", 
                               values = c("weibull", "exponential", "lognormal", "loglogistic"))
    ),
    properties = c("numerics", "factors", "weights", "prob", "rcens"),
    name = "Parametric Survival Model",
    short.name = "Parametric",
    note = "Based on the mlr3 surv.parametric learner"
  )
}

trainLearner.surv.parametric = function (.learner, .task, .subset, .weights = NULL, ...) 
{
  f    = getTaskFormula(.task)
  data = getTaskData(.task, subset = .subset)
  if (is.null(.weights)) {
    mod = survival::survreg(formula = f, data = data, ...)
  }
  else {
    mod = survival::survreg(formula = f, data = data, weights = .weights, ...)
  }
  mod
}

predictLearner.surv.parametric = function (.learner, .model, .newdata, ...) 
{
  predict(.model$learner.model, newdata = .newdata, type = "response", ...)
}
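The sketch below shows the whole learner in one self-contained piece. It assumes mlr 2.x and survival are installed. The data set, the task id, and the choice to leave "prob" out of the properties are illustrative assumptions.

In mlr, the default given for a learner parameter only documents the default of the underlying function. It is not passed on to trainLearner(). A value reaches survreg() only when it is set, for example through par.vals in makeRLearnerSurv() or as makeLearner("surv.parametric", dist = "weibull"). Setting dist explicitly in this way may help with the error above.

Two further details:
- makeParamSet() must not receive an empty trailing argument.
- The registerS3method() calls follow the pattern the mlr tutorial suggests for learners defined outside a package.

```r
library(mlr)
library(survival)

makeRLearner.surv.parametric = function() {
  makeRLearnerSurv(
    cl = "surv.parametric",
    package = "survival",
    par.set = makeParamSet(
      # 'default' is documentation only; mlr does not pass it to trainLearner()
      makeDiscreteLearnerParam(id = "dist", default = "weibull",
        values = c("weibull", "exponential", "lognormal", "loglogistic"))
    ),
    # values in par.vals ARE passed to trainLearner() via '...'
    par.vals = list(dist = "weibull"),
    # "prob" omitted: this sketch only returns a risk score (assumption)
    properties = c("numerics", "factors", "weights", "rcens"),
    name = "Parametric Survival Model",
    short.name = "Parametric"
  )
}

trainLearner.surv.parametric = function(.learner, .task, .subset, .weights = NULL, ...) {
  f = getTaskFormula(.task)
  data = getTaskData(.task, subset = .subset)
  # '...' carries the set hyperparameters, e.g. dist = "weibull"
  if (is.null(.weights)) {
    survival::survreg(formula = f, data = data, ...)
  } else {
    survival::survreg(formula = f, data = data, weights = .weights, ...)
  }
}

predictLearner.surv.parametric = function(.learner, .model, .newdata, ...) {
  # survreg's lp is on the log-time scale (larger = longer survival);
  # negate it so that larger values mean higher risk
  -predict(.model$learner.model, newdata = .newdata, type = "lp", ...)
}

# Register the methods for a learner that lives outside a package
registerS3method("makeRLearner", "surv.parametric", makeRLearner.surv.parametric)
registerS3method("trainLearner", "surv.parametric", trainLearner.surv.parametric)
registerS3method("predictLearner", "surv.parametric", predictLearner.surv.parametric)

# Illustrative task: complete cases of survival::lung (status 1 = censored, 2 = dead);
# mlr expects a logical event column
d = na.omit(survival::lung[, c("time", "status", "age", "sex")])
d$status = d$status == 2
task = makeSurvTask(id = "lung", data = d, target = c("time", "status"))

lrn = makeLearner("surv.parametric", id = "AFT")
mod = train(lrn, task)
pred = predict(mod, task = task)

# mlr's cindex measure needs the Hmisc package
if (requireNamespace("Hmisc", quietly = TRUE)) {
  print(performance(pred, measures = cindex))
}
```

Because par.vals is set, getHyperPars(lrn) reports dist = "weibull". Passing another value, such as makeLearner("surv.parametric", dist = "lognormal"), overrides it.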


Solution 1:[1]

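Based on the linked reference, the prediction function needs to return the linear predictor, which is type = "lp" rather than type = "response". There is also a sign mismatch between survreg's output and mlr's cindex measure. For an accelerated failure time model fitted with survreg, the linear predictor is on the log-time scale, so larger values mean longer survival (lower risk). mlr's cindex, however, treats larger predictions as higher risk. Following the linked discussion, negating the linear predictor resolves this. The prediction function therefore becomes: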

predictLearner.surv.parametric = function(.learner, .model, .newdata, ...) {
  # negate the log-time linear predictor so that larger values mean higher risk
  -predict(.model$learner.model, newdata = .newdata, type = "lp", ...)
}
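The following sketch shows why the minus sign matters, using only the survival package and no mlr. The lung data, the age + sex model, and the helper harrell_c() are illustrative assumptions. harrell_c() is a simplified Harrell-style C-index, not mlr's implementation. Like mlr's cindex, it treats larger scores as higher risk.

```r
library(survival)

# Illustrative data: complete cases of survival::lung (status 1 = censored, 2 = dead)
d = na.omit(survival::lung[, c("time", "status", "age", "sex")])
fit = survreg(Surv(time, status) ~ age + sex, data = d, dist = "weibull")

# AFT linear predictor: log-time scale, so larger lp = longer expected survival
lp = predict(fit, type = "lp")

# Hypothetical helper: simplified Harrell-style C-index where larger 'risk'
# should mean an earlier event. A pair (i, j) is comparable when i has an
# observed event strictly before time j; ties in risk count as 0.5.
harrell_c = function(time, event, risk) {
  comparable = outer(time, time, "<") & event  # row i recycles event[i]
  higher = outer(risk, risk, ">")
  tied = outer(risk, risk, "==")
  (sum(comparable & higher) + 0.5 * sum(comparable & tied)) / sum(comparable)
}

event = d$status == 2
c_lp = harrell_c(d$time, event, lp)    # raw lp misused as a risk score
c_neg = harrell_c(d$time, event, -lp)  # negated lp, as in the answer
print(c(lp = c_lp, negated = c_neg))
```

Negating the score swaps concordant and discordant pairs and leaves ties and comparable pairs unchanged, so the two values always sum to 1. On these data the raw lp scores below 0.5 and the negated lp scores above it.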

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Mary B