Add predictions to a data frame
Usage
add_predictions(data, model, var = "pred", type = NULL)
spread_predictions(data, ..., type = NULL)
gather_predictions(data, ..., .pred = "pred", .model = "model", type = NULL)
Arguments
- data
A data frame used to generate the predictions.
- model
A model. add_predictions() takes a single model.
- var
The name of the output column. Defaults to "pred".
- type
Prediction type, passed on to stats::predict(). Consult the predict() documentation for the given model to determine valid values.
- ...
Models. gather_predictions() and spread_predictions() take multiple models. Each model's name is taken from either the argument name or the name of the model object.
- .pred, .model
The names of the variables used by gather_predictions() for the predictions and the model name, respectively.
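The type argument only matters for models whose predict() method distinguishes prediction scales. Below is a minimal base-R sketch, assuming a logistic glm fitted to simulated data. add_pred_sketch() is a hypothetical helper that approximates add_predictions() by writing stats::predict() output into a new column and forwarding type only when it is supplied; it is not modelr's source code. For a glm, the default predictions are on the link (log-odds) scale, while type = "response" returns probabilities.

```r
# add_pred_sketch() is a hypothetical stand-in for add_predictions(),
# not modelr's implementation.
add_pred_sketch <- function(data, model, var = "pred", type = NULL) {
  # Forward `type` only when supplied, so each predict() method
  # otherwise falls back to its own default.
  data[[var]] <- if (is.null(type)) {
    stats::predict(model, data)
  } else {
    stats::predict(model, data, type = type)
  }
  data
}

# Simulated binary outcome (an assumption for illustration only)
set.seed(1)
train <- data.frame(x = runif(50))
train$y <- rbinom(50, 1, plogis(-1 + 3 * train$x))
fit <- glm(y ~ x, family = binomial, data = train)

grid <- data.frame(x = seq(0, 1, length.out = 5))
link <- add_pred_sketch(grid, fit)                     # log-odds (predict.glm default "link")
prob <- add_pred_sketch(grid, fit, type = "response")  # probabilities in (0, 1)
print(prob)
```

Because the "response" scale is the inverse logit of the "link" scale for this model, prob$pred equals plogis(link$pred).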
Value
A data frame. add_predictions() adds a single new column, with default name pred, to the input data.
spread_predictions() adds one column for each model. gather_predictions() adds two columns, named by .model and .pred (model and pred by default), and repeats the input rows for each model.
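To make the difference between the spread and gather shapes concrete, here is a base-R sketch using the built-in cars data and two lm() fits. spread_sketch() and gather_sketch() are hypothetical helpers written for illustration, not modelr's implementation; they only reproduce the output shapes described above (one column per model versus stacked rows with a model-name column). Unlike the real functions, they take an explicitly named list of models rather than capturing names from the arguments.

```r
# Hypothetical helpers illustrating output shapes; not modelr's code.
spread_sketch <- function(data, models) {
  for (nm in names(models)) {
    data[[nm]] <- stats::predict(models[[nm]], data)  # one column per model
  }
  data
}

gather_sketch <- function(data, models, .pred = "pred", .model = "model") {
  pieces <- lapply(names(models), function(nm) {
    out <- data
    out[[.pred]] <- stats::predict(models[[nm]], data)
    out[[.model]] <- nm  # length-1 value is recycled to every row
    out[c(.model, setdiff(names(out), .model))]  # model column first
  })
  do.call(rbind, pieces)  # input rows repeated once per model
}

m1 <- lm(dist ~ speed, data = cars)
m2 <- lm(dist ~ poly(speed, 2), data = cars)
grid <- data.frame(speed = c(5, 10, 15, 20))

wide <- spread_sketch(grid, list(m1 = m1, m2 = m2))  # columns: speed, m1, m2
long <- gather_sketch(grid, list(m1 = m1, m2 = m2))  # columns: model, speed, pred
print(wide)
print(long)
```

The wide form suits side-by-side comparison of models, while the long form is convenient for plotting with one grouping variable per model.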
Examples
library(modelr)
df <- tibble::tibble(
x = sort(runif(100)),
y = 5 * x + 0.5 * x ^ 2 + 3 + rnorm(length(x))
)
plot(df)
m1 <- lm(y ~ x, data = df)
grid <- data.frame(x = seq(0, 1, length.out = 10))
grid %>% add_predictions(m1)
#> x pred
#> 1 0.0000000 2.989671
#> 2 0.1111111 3.599630
#> 3 0.2222222 4.209588
#> 4 0.3333333 4.819547
#> 5 0.4444444 5.429505
#> 6 0.5555556 6.039464
#> 7 0.6666667 6.649422
#> 8 0.7777778 7.259381
#> 9 0.8888889 7.869339
#> 10 1.0000000 8.479298
m2 <- lm(y ~ poly(x, 2), data = df)
grid %>% spread_predictions(m1, m2)
#> x m1 m2
#> 1 0.0000000 2.989671 2.953507
#> 2 0.1111111 3.599630 3.584332
#> 3 0.2222222 4.209588 4.209973
#> 4 0.3333333 4.819547 4.830430
#> 5 0.4444444 5.429505 5.445702
#> 6 0.5555556 6.039464 6.055791
#> 7 0.6666667 6.649422 6.660696
#> 8 0.7777778 7.259381 7.260416
#> 9 0.8888889 7.869339 7.854953
#> 10 1.0000000 8.479298 8.444306
grid %>% gather_predictions(m1, m2)
#> model x pred
#> 1 m1 0.0000000 2.989671
#> 2 m1 0.1111111 3.599630
#> 3 m1 0.2222222 4.209588
#> 4 m1 0.3333333 4.819547
#> 5 m1 0.4444444 5.429505
#> 6 m1 0.5555556 6.039464
#> 7 m1 0.6666667 6.649422
#> 8 m1 0.7777778 7.259381
#> 9 m1 0.8888889 7.869339
#> 10 m1 1.0000000 8.479298
#> 11 m2 0.0000000 2.953507
#> 12 m2 0.1111111 3.584332
#> 13 m2 0.2222222 4.209973
#> 14 m2 0.3333333 4.830430
#> 15 m2 0.4444444 5.445702
#> 16 m2 0.5555556 6.055791
#> 17 m2 0.6666667 6.660696
#> 18 m2 0.7777778 7.260416
#> 19 m2 0.8888889 7.854953
#> 20 m2 1.0000000 8.444306