Easy way to allow predict() to work on user-defined function?

I have a "model" that is not an object of a recognized class with a predict method, but a user-defined function. How can I set this up so that predict(new_model, data) returns the function's output?

model <- lm(Sepal.Length ~ ., data = iris)
new_model <- function(data) (predict(model, newdata = data) - 5)^2

# want this to work
# predict(new_model, iris)

Created on 2021-06-24 by the reprex package (v1.0.0)

Hello @arthur.t,

This could help, or at least suggest a way forward:

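# A plain function standing in for the model
new_model <- function(data) (data - 5)^2
# Give it a custom S3 class so that predict() can dispatch on it
class(new_model) <- "arthur.t"
# S3 method; the argument names follow the predict(object, ...) generic
predict.arthur.t <- function(object, newdata, ...) object(newdata$Sepal.Length)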

predict(new_model, iris)
#>   [1] 0.01 0.01 0.09 0.16 0.00 0.16 0.16 0.00 0.36 0.01 0.16 0.04 0.04 0.49 0.64
#>  [16] 0.49 0.16 0.01 0.49 0.01 0.16 0.01 0.16 0.01 0.04 0.00 0.00 0.04 0.04 0.09
#>  [31] 0.04 0.16 0.04 0.25 0.01 0.00 0.25 0.01 0.36 0.01 0.00 0.25 0.36 0.00 0.01
#>  [46] 0.04 0.01 0.16 0.09 0.00 4.00 1.96 3.61 0.25 2.25 0.49 1.69 0.01 2.56 0.04
#>  [61] 0.00 0.81 1.00 1.21 0.36 2.89 0.36 0.64 1.44 0.36 0.81 1.21 1.69 1.21 1.96
#>  [76] 2.56 3.24 2.89 1.00 0.49 0.25 0.25 0.64 1.00 0.16 1.00 2.89 1.69 0.36 0.25
#>  [91] 0.25 1.21 0.64 0.00 0.36 0.49 0.49 1.44 0.01 0.49 1.69 0.64 4.41 1.69 2.25
#> [106] 6.76 0.01 5.29 2.89 4.84 2.25 1.96 3.24 0.49 0.64 1.96 2.25 7.29 7.29 1.00
#> [121] 3.61 0.36 7.29 1.69 2.89 4.84 1.44 1.21 1.96 4.84 5.76 8.41 1.96 1.69 1.21
#> [136] 7.29 1.69 1.96 1.00 3.61 2.89 3.61 0.64 3.24 2.89 2.89 1.69 2.25 1.44 0.81
Created on 2021-06-24 by the reprex package (v2.0.0)
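
To get closer to the original question, where the function wraps the fitted lm, the same pattern might look like the sketch below. The class name `transformed_model` is made up for illustration, and keeping "function" as a second class is a stylistic choice rather than a requirement.

```r
# Fit the base model from the question.
model <- lm(Sepal.Length ~ ., data = iris)

# Closure that transforms the base model's predictions.
new_model <- function(data) (predict(model, newdata = data) - 5)^2

# Tag the function with a hypothetical S3 class name.
# Listing "function" second keeps the object recognisable as a function.
class(new_model) <- c("transformed_model", "function")

# S3 method: predict() dispatches here because of the class attribute.
predict.transformed_model <- function(object, newdata, ...) {
  object(newdata)
}

preds <- predict(new_model, iris)
head(preds)
```

Here predict(new_model, iris) calls predict.transformed_model(), which in turn calls the closure, so the inner predict() still dispatches to the lm method.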

Thanks! Very enlightening.

I didn't realize it was so easy to create a minimal class with just a class name and predict method.
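
For anyone who wants to reuse the idea, here is a rough sketch of a small constructor that wraps any function this way. The names `as_predictor` and `fn_predictor` are hypothetical and not from any package.

```r
# Hypothetical constructor: attaches an S3 class to any function.
as_predictor <- function(f) {
  stopifnot(is.function(f))
  structure(f, class = c("fn_predictor", "function"))
}

# One predict method then covers every function wrapped by as_predictor().
predict.fn_predictor <- function(object, newdata, ...) {
  object(newdata, ...)
}

# Usage: the wrapped function receives the whole data frame.
square_dev <- as_predictor(function(data) (data$Sepal.Length - 5)^2)
first_three <- head(predict(square_dev, iris), 3)
first_three
```

The wrapped object is still callable directly, so square_dev(iris) and predict(square_dev, iris) should give the same result.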
