yhat(model) and yhat(model, newdata) return different results on the same data #3

@mkwiecinski

Hi,

The code below compares the output of two calls to yhat() on a model trained with randomForest. The first call passes only the model; the second also passes the newdata argument. When newdata is the same data the model was trained on, the two outputs differ, although I would expect them to be identical.

For comparison, I also include the votes component of the randomForest object. It matches the output of the yhat() call without newdata, so the discrepancy appears only when newdata is supplied.

library(tidyverse)
#> -- Attaching packages ------------------------------------------------------------------------------------------------------------ tidyverse 1.2.1 --
#> √ ggplot2 2.2.1     √ purrr   0.2.4
#> √ tibble  1.4.1     √ dplyr   0.7.4
#> √ tidyr   0.7.2     √ stringr 1.2.0
#> √ readr   1.1.1     √ forcats 0.2.0
#> -- Conflicts --------------------------------------------------------------------------------------------------------------- tidyverse_conflicts() --
#> x dplyr::filter() masks stats::filter()
#> x dplyr::lag()    masks stats::lag()
library(DALEX2)
#> Welcome to DALEX2 (version: 0.9).
#> 
#> Attaching package: 'DALEX2'
#> The following object is masked from 'package:dplyr':
#> 
#>     explain
library(randomForest)
#> randomForest 4.6-12
#> Type rfNews() to see new features/changes/bug fixes.
#> 
#> Attaching package: 'randomForest'
#> The following object is masked from 'package:dplyr':
#> 
#>     combine
#> The following object is masked from 'package:ggplot2':
#> 
#>     margin

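# Predictors and a binary target (1 = fired, 0 = otherwise) from the HR data
X <- HR %>% select(-status, -gender)
Y <- HR %>% select(status) %>% unlist()
Y <- if_else(Y=="fired",1,0) %>% as.factor(.)

rf.model <- randomForest(x = X, y = Y, ntree = 500, localImp = TRUE)

f1 <- yhat(rf.model)             # no newdata
f2 <- yhat(rf.model, newdata=X)  # newdata = the training data
f3 <- rf.model$votes             # votes stored in the fitted object

# Probability of class "1" from each of the three sources, side by side
cbind(f1[,2],f2[,2],f3[,2]) %>% as_data_frame %>% head(.)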
#> # A tibble: 6 x 3
#>       V1     V2     V3
#>    <dbl>  <dbl>  <dbl>
#> 1 0.872  0.938  0.872 
#> 2 0.982  0.994  0.982 
#> 3 0.983  0.994  0.983 
#> 4 0.764  0.866  0.764 
#> 5 0.0585 0.0220 0.0585
#> 6 0.624  0.794  0.624
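
One possible explanation, sketched here without DALEX2: predict() for a randomForest object returns out-of-bag (OOB) estimates when newdata is omitted, and the votes component stores those same OOB votes. When newdata is supplied, every tree votes on every row, including rows that tree was trained on. I am assuming that yhat() simply forwards to predict() in both cases; I have not checked how it dispatches internally. The snippet uses the built-in iris data instead of HR so that it is self-contained, and the names rf, oob and inbag are only illustrative.

```r
library(randomForest)

set.seed(1)
rf <- randomForest(x = iris[, 1:4], y = iris$Species, ntree = 200)

# No newdata: out-of-bag class probabilities, each row is scored only by
# the trees that did not see it during training
oob <- predict(rf, type = "prob")

# newdata = training data: all trees vote on every row
inbag <- predict(rf, newdata = iris[, 1:4], type = "prob")

# Compare the OOB probabilities with the votes stored in the fitted object
all.equal(as.numeric(oob), as.numeric(rf$votes))

# Largest absolute gap between OOB and in-bag probabilities
max(abs(as.numeric(oob) - as.numeric(inbag)))
```

If this is what happens here, the V1/V3 columns above would be OOB estimates and V2 the more optimistic in-bag predictions, which would make the difference expected rather than a bug in the newdata path.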
