Árvores de Decisão

Bruna Wundervald
brunadaviesw at gmail.com
Departamento de Estatística - UFPR

Introdução

Árvores de Regressão

Em resumo, para a construção de árvores de regressão são necessários 2 passos:

  1. O espaço da variável preditora é dividido em \(J\) regiões distintas, de forma que ela não se sobreponham, \(R_1, R_2...R_J\).

  2. Para cada observação que estiver na região \(R_j\), é feita a mesma predição, que na verdade é apenas a média dos pontos presentes em \(R_j\).

Assim, após obtidas as regiões, as predições das observações delas serão a média dos pontos em cada região. Por exemplo, caso seja obtida a uma primeira região \(R_1\), e a média dos pontos nesta região seja 10, os valores preditos das observações que cairem nela valerão 10, e assim sucessivamente. Mas como encontrar as regiões ótimas?

Partições ótimas

Na prática, é computacionalmente inviável realizar todas as partições possíveis em \(J\) caixas. Por esse motivo, é considerada a abordagem da divisão binária recursiva. Começando no topo da árvore, onde todas as observações pertencem a uma mesma região, são feitas divisões sucessivas no espaço das preditoras, sendo que cada divisão gera dois novos ramos abaixo da árvore.

Em cada passo, a divisão feita considera apenas o estado atual da árvore, e não situações futuras. Estas divisões selecionam, primeiramente, a variável preditora \(X_j\), e o ponto de corte que leva à maior reduçãos na SQR. A seguir, as divisões continuam sendo feitas, mas sempre partindo das anteriores, e não mais da região que contém todos os pontos. O processo continua até algum critério de parada ser atingido.

Podagem de árvores

  • Uma estratégia adotada quando são construídas árvores de regressão é chegar em uma grande árvore e ir podando. O objetivo é selecionar uma sub-árvore que dê o menor erro na amostra de validação. Como fazer isso para todas as sub-aŕvores possíveis não é prático, um subconjunto delas é selecionado. É considerada uma sequência de árvores indexadas por um parâmetro de tuning \(\alpha\). Cada valor de \(\alpha\) corresponde a uma sub-árvore $ T T_0$ tal que:

\[ \sum_{m = 1}^{|T|} \sum_{i: x_i \in R_m} (y_i - \hat y_{R_m})^2 + \alpha |T| \]

Seja o menor possível. Basicamente, o \(\alpha\) controla o trade-off entre a complexidade da sub-árvore e seu ajuste aos dados de treinamento. Isto é, quando \(\alpha\) é grande, existe uma penalização pela sua quantidade de nós terminais.

set.seed(20172)
# Carregamento de pacotes
library(tidyverse)
## Loading tidyverse: ggplot2
## Loading tidyverse: tibble
## Loading tidyverse: tidyr
## Loading tidyverse: readr
## Loading tidyverse: purrr
## Loading tidyverse: dplyr
## Conflicts with tidy packages ----------------------------------------------
## filter(): dplyr, stats
## lag():    dplyr, stats
library(ggplot2)
library(plyr)
## -------------------------------------------------------------------------
## You have loaded plyr after dplyr - this is likely to cause problems.
## If you need functions from both plyr and dplyr, please load plyr first, then dplyr:
## library(plyr); library(dplyr)
## -------------------------------------------------------------------------
## 
## Attaching package: 'plyr'
## The following objects are masked from 'package:dplyr':
## 
##     arrange, count, desc, failwith, id, mutate, rename, summarise,
##     summarize
## The following object is masked from 'package:purrr':
## 
##     compact
library(dplyr)
library(gridExtra)
## 
## Attaching package: 'gridExtra'
## The following object is masked from 'package:dplyr':
## 
##     combine
# Leitura e organização da base:
# 1. Leitura
# 2. Filtra apenas quem mora no centro e tem renda maior que 0
# 3. Seleciona apenas as colunas de interesse
# 4. Transforma algumas colunas em fator e decide quem fará parte da 
# amostra de treino e da de teste

db <- read.csv("/home/bruna/GIT/Machine Learning/Dados/dados2.txt", 
               header = TRUE, 
               sep = "\t",
               encoding = "UTF-8") %>% 
  filter(val_desp_aluguel_fam >  0) %>% 
  dplyr::select(c("endereco", "val_desp_aluguel_fam", 
             "qtd_comodos_domic_fam", "qtd_comodos_dormitorio_fam",
             "cod_material_piso_fam", "cod_material_domic_fam",
             "cod_iluminacao_domic_fam",
             "qtd_pessoas_domic_fam")) %>% 
  dplyr::mutate(cod_iluminacao_domic_fam = factor(cod_iluminacao_domic_fam),
                cod_material_domic_fam = factor(cod_material_domic_fam),
                cod_material_piso_fam = factor(cod_material_piso_fam))
dim(db)
## [1] 14100     8
getFactor <- function(x) {
   x <- na.omit(x)
   tb <- table(x)
   nm <- names(tb)[tb == max(tb)]
   return(sample(nm, 1))
}

db.aj <- db %>%  dplyr::group_by(factor(endereco)) %>% 
  dplyr::summarise(total.count = n(), 
                   m.aluguel = mean(val_desp_aluguel_fam, na.rm=TRUE),
                   max.dorm = max(qtd_comodos_dormitorio_fam, na.rm=TRUE),
                   max.com = max(qtd_comodos_domic_fam, na.rm=TRUE),
                   max.pessoas = max(qtd_pessoas_domic_fam, na.rm = TRUE),
                   piso = getFactor(cod_material_piso_fam),
                   material = getFactor(cod_material_domic_fam),
                   ilum = getFactor(cod_iluminacao_domic_fam)) %>%  
  filter(m.aluguel < 3000, max.com > 0) %>% 
  dplyr::mutate(part = ifelse(runif(3945) > 0.3, "treino", "teste"))

names(db.aj)
##  [1] "factor(endereco)" "total.count"      "m.aluguel"       
##  [4] "max.dorm"         "max.com"          "max.pessoas"     
##  [7] "piso"             "material"         "ilum"            
## [10] "part"
dim(db.aj)
## [1] 3945   10
# Descritiva das variáveis
p1 <- ggplot(data = db.aj, 
       aes(y = m.aluguel, x = factor(material))) +
  geom_boxplot(aes(fill = factor(material)), colour = "ivory4") +
  xlab("Tipo de material") +
  ylab("Aluguel") +
  guides(fill=FALSE)


p2 <- ggplot(data = db.aj, 
             aes(y = m.aluguel, x = factor(piso))) +
  geom_boxplot(aes(fill = factor(piso)), colour = "ivory4") +
  xlab("Tipo de piso") +
  ylab("Aluguel") +
  guides(fill=FALSE)

p3 <- ggplot(data = db.aj, 
             aes(y = m.aluguel, x = factor(ilum))) +
  geom_boxplot(aes(fill = factor(ilum)), colour = "ivory4") +
  xlab("Tipo de iluminação") +
  ylab("Aluguel") +
  guides(fill=FALSE)

p4 <- ggplot(data = db.aj, 
             aes(y = m.aluguel, x = max.dorm)) +
  geom_boxplot(aes(fill = factor(max.dorm)), colour = "ivory4") +
  xlab("Quantidade de dormitórios") +
  ylab("Aluguel") +
  guides(fill=FALSE)

p5 <- ggplot(data = db.aj, 
             aes(y = m.aluguel, x = factor(max.com))) +
  geom_boxplot(aes(fill = factor(max.com)), colour = "ivory4") +
  xlab("Quantidade de dormitórios") +
  ylab("Aluguel") +
  guides(fill=FALSE)

p6 <- ggplot(data = db.aj, 
             aes(x = 1, y = m.aluguel)) +
  geom_boxplot(fill = "tomato", colour = "ivory4") +
  ylab("Aluguel") +
  xlab("") + 
  guides(fill=FALSE)


multiplot <- function(..., plotlist=NULL, file, cols=1, layout=NULL) {
  library(grid)

  # Make a list from the ... arguments and plotlist
  plots <- c(list(...), plotlist)

  numPlots = length(plots)

  # If layout is NULL, then use 'cols' to determine layout
  if (is.null(layout)) {
    # Make the panel
    # ncol: Number of columns of plots
    # nrow: Number of rows needed, calculated from # of cols
    layout <- matrix(seq(1, cols * ceiling(numPlots/cols)),
                    ncol = cols, nrow = ceiling(numPlots/cols))
  }

 if (numPlots==1) {
    print(plots[[1]])

  } else {
    # Set up the page
    grid.newpage()
    pushViewport(viewport(layout = grid.layout(nrow(layout), ncol(layout))))

    # Make each plot, in the correct location
    for (i in 1:numPlots) {
      # Get the i,j matrix positions of the regions that contain this subplot
      matchidx <- as.data.frame(which(layout == i, arr.ind = TRUE))

      print(plots[[i]], vp = viewport(layout.pos.row = matchidx$row,
                                      layout.pos.col = matchidx$col))
    }
  }
}

multiplot(p1, p2, p3, p4, p5, p6, cols = 3) 
## Warning: Removed 3 rows containing non-finite values (stat_boxplot).

#-------------------------------------------------------------
# Àrvores de Regressão
#-------------------------------------------------------------
library(rpart)
# Separação em treino e teste
db.teste <- db.aj %>% filter(part == "teste")
db.treino <- db.aj %>% filter(part == "treino")


t1 <- rpart(m.aluguel ~ max.dorm + max.com + max.pessoas + 
             piso + material + ilum, data = db.treino)
summary(t1)
## Call:
## rpart(formula = m.aluguel ~ max.dorm + max.com + max.pessoas + 
##     piso + material + ilum, data = db.treino)
##   n= 2773 
## 
##           CP nsplit rel error    xerror       xstd
## 1 0.09431424      0 1.0000000 1.0008109 0.03617327
## 2 0.01842549      1 0.9056858 0.9169132 0.03295848
## 3 0.01696092      2 0.8872603 0.9123584 0.03228755
## 4 0.01331811      3 0.8702994 0.8941516 0.03208256
## 5 0.01000000      4 0.8569812 0.8827626 0.03156282
## 
## Variable importance
##     max.com    max.dorm    material max.pessoas        piso        ilum 
##          48          31           8           8           3           2 
## 
## Node number 1: 2773 observations,    complexity param=0.09431424
##   mean=465.7019, MSE=26175.53 
##   left son=2 (1475 obs) right son=3 (1298 obs)
##   Primary splits:
##       max.com     < 4.5 to the left,  improve=0.09431424, (0 missing)
##       max.dorm    < 1.5 to the left,  improve=0.08548159, (3 missing)
##       piso        splits as  LLLLRRR, improve=0.02579073, (0 missing)
##       ilum        splits as  RLL,     improve=0.02081531, (0 missing)
##       max.pessoas < 2.5 to the left,  improve=0.02065213, (3 missing)
##   Surrogate splits:
##       max.dorm    < 1.5 to the left,  agree=0.809, adj=0.591, (0 split)
##       max.pessoas < 3.5 to the left,  agree=0.615, adj=0.178, (0 split)
##       ilum        splits as  RLR,     agree=0.555, adj=0.049, (0 split)
##       piso        splits as  LLLRLRL, agree=0.536, adj=0.009, (0 split)
##       material    splits as  LLRLLLL, agree=0.532, adj=0.001, (0 split)
## 
## Node number 2: 1475 observations,    complexity param=0.01696092
##   mean=419.0921, MSE=19591.43 
##   left son=4 (622 obs) right son=5 (853 obs)
##   Primary splits:
##       max.com  < 3.5 to the left,  improve=0.04260264, (0 missing)
##       max.dorm < 1.5 to the left,  improve=0.03409431, (1 missing)
##       piso     splits as  LLLLR-L, improve=0.02699821, (0 missing)
##       ilum     splits as  RL-,     improve=0.01958444, (0 missing)
##       material splits as  RLLRRLL, improve=0.01349931, (0 missing)
##   Surrogate splits:
##       max.dorm    < 1.5 to the left,  agree=0.685, adj=0.254, (0 split)
##       max.pessoas < 1.5 to the left,  agree=0.616, adj=0.090, (0 split)
##       ilum        splits as  RL-,     agree=0.599, adj=0.050, (0 split)
##       piso        splits as  RLRRR-L, agree=0.584, adj=0.014, (0 split)
##       material    splits as  RLRRRRR, agree=0.579, adj=0.002, (0 split)
## 
## Node number 3: 1298 observations,    complexity param=0.01842549
##   mean=518.6677, MSE=28383.36 
##   left son=6 (243 obs) right son=7 (1055 obs)
##   Primary splits:
##       material splits as  RRL-RLL, improve=0.036301600, (0 missing)
##       max.dorm < 2.5 to the left,  improve=0.023621280, (2 missing)
##       piso     splits as  LLLRRRR, improve=0.022195930, (0 missing)
##       max.com  < 5.5 to the left,  improve=0.011894450, (0 missing)
##       ilum     splits as  RLL,     improve=0.002974448, (0 missing)
##   Surrogate splits:
##       piso splits as  RRLLRRR, agree=0.884, adj=0.379, (0 split)
## 
## Node number 4: 622 observations
##   mean=385.2598, MSE=17468.4 
## 
## Node number 5: 853 observations
##   mean=443.7622, MSE=19696.27 
## 
## Node number 6: 243 observations
##   mean=451.7843, MSE=20654.93 
## 
## Node number 7: 1055 observations,    complexity param=0.01331811
##   mean=534.073, MSE=28895.77 
##   left son=14 (877 obs) right son=15 (178 obs)
##   Primary splits:
##       max.dorm    < 2.5 to the left,  improve=0.031170700, (2 missing)
##       max.com     < 5.5 to the left,  improve=0.012174800, (0 missing)
##       piso        splits as  LLLRRRR, improve=0.010626400, (0 missing)
##       ilum        splits as  RLL,     improve=0.004787784, (0 missing)
##       max.pessoas < 3.5 to the left,  improve=0.001928826, (2 missing)
##   Surrogate splits:
##       max.com     < 5.5 to the left,  agree=0.855, adj=0.140, (2 split)
##       max.pessoas < 8.5 to the left,  agree=0.833, adj=0.011, (0 split)
##       material    splits as  LL--R--, agree=0.832, adj=0.006, (0 split)
## 
## Node number 14: 877 observations
##   mean=520.4357, MSE=25611.22 
## 
## Node number 15: 178 observations
##   mean=601.2635, MSE=39647.79
plot(t1, uniform = FALSE, branch = 1, compress = FALSE, nspace,
     margin = 0)
text(t1, use.n = TRUE)

printcp(t1)
## 
## Regression tree:
## rpart(formula = m.aluguel ~ max.dorm + max.com + max.pessoas + 
##     piso + material + ilum, data = db.treino)
## 
## Variables actually used in tree construction:
## [1] material max.com  max.dorm
## 
## Root node error: 72584732/2773 = 26176
## 
## n= 2773 
## 
##         CP nsplit rel error  xerror     xstd
## 1 0.094314      0   1.00000 1.00081 0.036173
## 2 0.018425      1   0.90569 0.91691 0.032958
## 3 0.016961      2   0.88726 0.91236 0.032288
## 4 0.013318      3   0.87030 0.89415 0.032083
## 5 0.010000      4   0.85698 0.88276 0.031563