@@ -186,6 +186,121 @@ tidypredict_test_default <- function(
186186 structure(results , class = c(" tidypredict_test" , " list" ))
187187}
188188
189+ # ' @export
190+ tidypredict_test.glmnet <- function (
191+ model ,
192+ df = model $ model ,
193+ threshold = 0.000000000001 ,
194+ include_intervals = FALSE ,
195+ max_rows = NULL ,
196+ xg_df = NULL
197+ ) {
198+ offset <- model $ call $ offset
199+ ismodels <- paste0(colnames(model $ model ), collapse = " " ) ==
200+ paste0(colnames(df ), collapse = " " )
201+
202+ if (! is.null(offset ) && ismodels ) {
203+ index <- colnames(df ) == " (offset)"
204+ colnames(df ) <- replace(colnames(df ), index , as.character(offset ))
205+ }
206+
207+ interval <- " none"
208+ if (include_intervals ) {
209+ interval <- " prediction"
210+ }
211+
212+ if (is.numeric(max_rows )) {
213+ df <- head(df , max_rows )
214+ }
215+
216+ preds <- predict(model , as.matrix(df ), interval = interval , type = " response" )
217+
218+ if (! include_intervals ) {
219+ base <- data.frame (fit = as.vector(preds ), row.names = NULL )
220+ } else {
221+ base <- as.data.frame(preds )
222+ }
223+
224+ te <- tidypredict_to_column(
225+ df ,
226+ model ,
227+ add_interval = include_intervals ,
228+ vars = c(" fit_te" , " upr_te" , " lwr_te" )
229+ )
230+ if (include_intervals ) {
231+ te <- te [, c(" fit_te" , " upr_te" , " lwr_te" )]
232+ } else {
233+ te <- data.frame (fit_te = te [, " fit_te" ])
234+ }
235+
236+ raw_results <- cbind(base , te )
237+ raw_results $ fit_diff <- raw_results $ fit - raw_results $ fit_te
238+ raw_results $ fit_threshold <- abs(raw_results $ fit_diff ) > threshold
239+
240+ if (include_intervals ) {
241+ raw_results $ lwr_diff <- abs(raw_results $ lwr - raw_results $ lwr_te )
242+ raw_results $ upr_diff <- abs(raw_results $ upr - raw_results $ upr_te )
243+ raw_results $ lwr_threshold <- raw_results $ lwr_diff > threshold
244+ raw_results $ upr_threshold <- raw_results $ upr_diff > threshold
245+ }
246+
247+ rowid <- seq_len(nrow(raw_results ))
248+ raw_results <- cbind(data.frame (rowid ), raw_results )
249+
250+ threshold_df <- data.frame (fit_threshold = sum(raw_results $ fit_threshold ))
251+ if (include_intervals ) {
252+ threshold_df $ lwr_threshold <- sum(raw_results $ lwr_threshold )
253+ threshold_df $ upr_threshold <- sum(raw_results $ upr_threshold )
254+ }
255+
256+ alert <- any(threshold_df > 0 )
257+
258+ message <- paste0(
259+ " tidypredict test results\n " ,
260+ " Difference threshold: " ,
261+ threshold ,
262+ " \n "
263+ )
264+
265+ if (alert ) {
266+ difference <- data.frame (fit_diff = max(raw_results $ fit_diff ))
267+ if (include_intervals ) {
268+ difference $ lwr_diff <- max(raw_results $ lwr_diff )
269+ difference $ upr_diff <- max(raw_results $ upr_diff )
270+ }
271+ message <- paste0(
272+ message ,
273+ " \n Fitted records above the threshold: " ,
274+ threshold_df $ fit_threshold ,
275+ if (! is.null(threshold_df $ lwr_threshold )) {
276+ " \n Lower interval records above the threshold: "
277+ },
278+ threshold_df $ lwr_threshold ,
279+ if (! is.null(threshold_df $ upr_threshold )) {
280+ " \n Upper interval records above the threshold: "
281+ },
282+ threshold_df $ upr_threshold ,
283+ " \n\n Fit max difference:" ,
284+ difference $ upr_diff ,
285+ " \n Lower max difference:" ,
286+ difference $ lwr_diff ,
287+ " \n Upper max difference:" ,
288+ difference $ fit_diff
289+ )
290+ } else {
291+ message <- paste0(
292+ message ,
293+ " \n All results are within the difference threshold"
294+ )
295+ }
296+ results <- list ()
297+ results $ model_call <- model $ call
298+ results $ raw_results <- raw_results
299+ results $ message <- message
300+ results $ alert <- alert
301+ structure(results , class = c(" tidypredict_test" , " list" ))
302+ }
303+
189304# ' @export
190305tidypredict_test.xgb.Booster <- function (
191306 model ,
0 commit comments