Model building

Prerequisites

library(tidyverse)
library(modelr)
options(na.action = na.warn)

library(nycflights13)
library(lubridate)

Why are low quality diamonds more expensive?

In Lecture 4 we saw a surprising relationship between the quality of diamonds and their price: low quality diamonds (poor cuts, bad colours, and inferior clarity) have higher prices!

ggplot(diamonds, aes(cut, price)) + geom_boxplot()

ggplot(diamonds, aes(color, price)) + geom_boxplot()

ggplot(diamonds, aes(clarity, price)) + geom_boxplot()

The worst diamond color is J, and the worst clarity is I1.

Price and carat

Confounding

In statistics, a confounder (also confounding variable, confounding factor or lurking variable) is a variable that influences both the dependent variable and independent variable causing a spurious association.

Wikipedia

What would be the single most important factor for determining the price of the diamond?

ggplot(diamonds, aes(carat, price)) + 
  geom_hex(bins = 50)

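The weight (carat) of the diamond is the single most important factor for its price, and lower quality diamonds tend to be larger. Carat is therefore a confounder: it is strongly related to price and is also associated with the quality grades.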

We can make it easier to see how the other attributes of a diamond affect its relative price by fitting a model to separate out the effect of carat.
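A small simulation makes the idea concrete. The sketch below uses made-up data (the variable names, coefficients and noise levels are assumptions chosen for illustration, not estimates from diamonds): larger stones tend to have lower quality, and both size and quality raise price. Regressing price on quality alone gives a negative slope; adding carat to the model recovers the positive effect of quality.

```r
set.seed(42)
n <- 1000

# Hypothetical data-generating process (NOT fitted to the real diamonds data):
# bigger stones tend to have lower quality, and both size and quality raise price.
carat   <- runif(n, min = 0.2, max = 2.5)
quality <- -carat + rnorm(n, sd = 0.3)
price   <- 2 * carat + 0.5 * quality + rnorm(n, sd = 0.1)

# Naive model: ignores the confounder, so the quality slope comes out negative
naive <- lm(price ~ quality)

# Adjusted model: holds carat fixed, so the quality slope is close to the true 0.5
adjusted <- lm(price ~ carat + quality)

coef(naive)["quality"]
coef(adjusted)["quality"]
```

Fitting lprice ~ lcarat and then looking at the residuals, as done below, is another way of "holding carat fixed".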

Some tweaks to the diamonds dataset for clarity:

  1. Focus on diamonds of 2.5 carats or less (99.7% of the data).
  2. Log-transform the carat and price variables.

diamonds2 <- diamonds %>% 
  filter(carat <= 2.5) %>% 
  mutate(lprice = log2(price), lcarat = log2(carat))
ggplot(diamonds2, aes(lcarat, lprice)) + 
  geom_hex(bins = 50)

The log-transformation makes the pattern linear.
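One way to see why the log-log transformation straightens the pattern: suppose price roughly follows a power law in carat (an assumption suggested by the plot; the constants c and b are introduced here for illustration). Taking log2 of both sides turns the curve into a straight line in lcarat:

```latex
\text{price} \approx c \cdot \text{carat}^{\,b}
\quad\Longrightarrow\quad
\underbrace{\log_2(\text{price})}_{\text{lprice}} \approx \log_2 c + b \cdot \underbrace{\log_2(\text{carat})}_{\text{lcarat}}
```

So lm(lprice ~ lcarat) estimates an intercept a (approximately log2 c) and a slope b, and back-transforming a prediction with 2 ^ lprice gives the curve price = 2^a * carat^b.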

Fit a linear model that captures the strong linear pattern, so that we can remove it:

mod_diamond <- lm(lprice ~ lcarat, data = diamonds2)

Visualise the fitted model after undoing the log transformation:

grid <- diamonds2 %>% 
  data_grid(carat = seq_range(carat, 20)) %>% 
  mutate(lcarat = log2(carat)) %>% 
  add_predictions(mod_diamond, "lprice") %>% 
  mutate(price = 2 ^ lprice)

ggplot(diamonds2, aes(carat, price)) + 
  geom_hex(bins = 50) + 
  geom_line(data = grid, colour = "red", size = 1)

Now look at the residuals to see if we’ve successfully removed the strong linear pattern:

diamonds2 <- diamonds2 %>% 
  add_residuals(mod_diamond, "lresid")

ggplot(diamonds2, aes(lcarat, lresid)) + 
  geom_hex(bins = 50)
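A property worth keeping in mind when reading this plot: for a least-squares fit with an intercept, the residuals always sum to zero and are exactly uncorrelated with the predictor, so the residual plot cannot show a leftover straight-line trend. What it can reveal is curvature or changing spread. The sketch below checks this on simulated, deliberately curved data (the data are made up for illustration).

```r
set.seed(1)
x <- runif(200, min = -1, max = 1)
y <- 1 + 2 * x + 3 * x^2 + rnorm(200, sd = 0.2)   # curved on purpose

fit <- lm(y ~ x)          # straight-line model, misses the curvature
r   <- resid(fit)

sum(r)                    # ~0, up to floating-point error
cor(r, x)                 # ~0: no linear trend left in the residuals
cor(r, x^2)               # clearly positive: the curvature is still there
```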

Re-do our motivating plots using those residuals instead of price:

ggplot(diamonds2, aes(cut, lresid)) + geom_boxplot()

ggplot(diamonds2, aes(color, lresid)) + geom_boxplot()

ggplot(diamonds2, aes(clarity, lresid)) + geom_boxplot()

Now, as the quality of the diamond increases, so too does its relative price, as expected.

Interpreting the y axis

  • A residual of -1 indicates that lprice was 1 unit lower than a prediction based solely on its weight.

  • Because we used log2(), a residual of -1 means the diamond costs half the predicted price, and a residual of 1 means it costs twice the predicted price.
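Converting a log2 residual back into a price ratio is a single exponentiation. A minimal sketch with a few residual values chosen for illustration:

```r
lresid <- c(-1, 0, 1, 2)
price_ratio <- 2 ^ lresid   # actual price / predicted price
price_ratio
## [1] 0.5 1.0 2.0 4.0
```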

A more complicated model

mod_diamond2 <- lm(lprice ~ lcarat + color + cut + clarity, data = diamonds2)

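We can visualise the effect of each predictor separately. Here are cut and color; clarity works the same way. Passing .model = mod_diamond2 makes data_grid() fill in the remaining predictors with typical values: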

grid <- diamonds2 %>% 
  data_grid(cut, .model = mod_diamond2) %>% 
  add_predictions(mod_diamond2)
grid
ggplot(grid, aes(cut, pred)) + 
  geom_point()

grid <- diamonds2 %>% 
  data_grid(color, .model = mod_diamond2) %>% 
  add_predictions(mod_diamond2)
grid
ggplot(grid, aes(color, pred)) + 
  geom_point()

Residual plot

diamonds2 <- diamonds2 %>% 
  add_residuals(mod_diamond2, "lresid2")

ggplot(diamonds2, aes(lcarat, lresid2)) + 
  geom_hex(bins = 50)

This plot indicates that there are some diamonds with quite large residuals. Remember that a residual of 2 means the diamond costs 4x the price we expected.

It’s often useful to look at unusual values individually:

diamonds2 %>% 
  filter(abs(lresid2) > 1) %>% 
  add_predictions(mod_diamond2) %>% 
  mutate(pred = round(2 ^ pred)) %>% 
  select(price, pred, carat:table, x:z) %>% 
  arrange(price)  # sort by price, ascending (use desc(price) for descending)

Nothing really stands out here. If the data did contain pricing errors, though, this is how you would find diamonds that are priced too low, which could be an opportunity to buy them cheaply.

What affects the number of daily flights?

Start by counting the number of flights that leave NYC each day. The result has 365 rows (one for each day of 2013) and 2 columns (date and the count n):

daily <- flights %>% 
  mutate(date = make_date(year, month, day)) %>% 
  group_by(date) %>% 
  summarise(n = n(), .groups = "drop")
daily
ggplot(daily, aes(date, n)) + 
  geom_line()

Day of week effect

Distribution of flight numbers by day-of-week:

daily <- daily %>% 
  mutate(wday = wday(date, label = TRUE))
ggplot(daily, aes(wday, n)) + 
  geom_boxplot()

There are fewer flights on weekends, especially on Saturdays (why?).

Fit a model to remove this strong pattern:

mod <- lm(n ~ wday, data = daily)

grid <- daily %>% 
  data_grid(wday) %>% 
  add_predictions(mod, "n")

ggplot(daily, aes(wday, n)) + 
  geom_boxplot() +
  geom_point(data = grid, colour = "red", size = 4)
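With a single categorical predictor, lm() simply predicts the mean of each group, so the red points are the per-weekday means and the residuals computed next are each day's count minus its weekday's mean. This holds whatever contrasts R uses for the factor (wday(label = TRUE) returns an ordered factor, which R codes with polynomial contrasts by default). A minimal sketch on made-up counts (not the real nycflights13 numbers), using base R's predict() in place of data_grid() and add_predictions():

```r
# Made-up counts for illustration only (NOT the real nycflights13 numbers)
toy <- data.frame(
  wday = factor(c("Mon", "Mon", "Mon", "Sat", "Sat", "Sat"),
                levels = c("Mon", "Sat")),
  n    = c(930, 950, 970, 700, 720, 740)
)

mod_toy <- lm(n ~ wday, data = toy)

# One prediction per level, like data_grid(wday) + add_predictions()
newdata <- data.frame(wday = factor(c("Mon", "Sat"), levels = c("Mon", "Sat")))
pred <- predict(mod_toy, newdata = newdata)   # 950 for Mon, 720 for Sat

# The predictions are exactly the per-day means ...
group_means <- tapply(toy$n, toy$wday, mean)

# ... so the residuals are each count minus its day's mean
resid(mod_toy)
```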

Next we compute and visualise the residuals:

daily <- daily %>% 
  add_residuals(mod)
daily %>% 
  ggplot(aes(date, resid)) + 
  geom_ref_line(h = 0) + 
  geom_line()

This plot shows how far each day deviates from the number of flights expected for its day of the week. Compare it with the plot of the raw daily counts above.

Some of the subtler patterns still remain:

  1. There is still a strong regular pattern starting in June. To see why, draw a plot with one line for each day of the week:

    ggplot(daily, aes(date, resid, colour = wday)) + 
      geom_ref_line(h = 0) + 
      geom_line()

    Our model fails to accurately predict the number of flights on Saturdays: during summer there are more flights than we expect, and during fall there are fewer.

  2. There are some days with far fewer flights than expected:

    daily %>% 
      filter(resid < -100)

    If you’re familiar with American public holidays, you might spot New Year’s Day, July 4th, Thanksgiving and Christmas. You’ll deal with these in a homework assignment.

  3. There seems to be some smoother long term trend over the course of a year.

    daily %>% 
      ggplot(aes(date, resid)) + 
      geom_ref_line(h = 0) + 
      geom_line(colour = "grey50") + 
      geom_smooth(se = FALSE, span = 0.20)

    There are fewer flights in January (and December), and more in summer (May-Sep). We can’t do much with this pattern quantitatively, because we only have a single year of data. But we can use our domain knowledge to brainstorm potential explanations.

Seasonal Saturday effect

Go back to the raw numbers, focussing on Saturdays:

daily %>% 
  filter(wday == "Sat") %>% 
  ggplot(aes(date, n)) + 
    geom_point() + 
    geom_line() +
    scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")

This pattern is probably caused by summer holidays: many people go on holiday in the summer, and people don’t mind travelling on Saturdays for vacation. Looking at this plot, we might guess that summer holidays run from early June to late August. That lines up fairly well with New York State’s school terms: summer break in 2013 was Jun 26–Sep 9.

(Then why are there more Saturday flights in the Spring than the Fall?)

Create a “term” variable that roughly captures the three school terms, and check with a plot:

term <- function(date) {
  cut(date, 
    breaks = ymd(20130101, 20130605, 20130825, 20140101),
    labels = c("spring", "summer", "fall") 
  )
}

daily <- daily %>% 
  mutate(term = term(date)) 

daily %>% 
  filter(wday == "Sat") %>% 
  ggplot(aes(date, n, colour = term)) +
  geom_point(alpha = 1/3) + 
  geom_line() +
  scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
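The term() helper relies on cut() for dates. A minimal base-R sketch of how the boundaries behave, using as.Date() in place of ymd() and a few probe dates picked for illustration: for Date input, cut() uses right = FALSE by default, so each term includes its start date and excludes its end date. A date that falls exactly on a break therefore belongs to the later term.

```r
# Break points as plain Dates (equivalent to ymd(20130101, ...) without lubridate)
term_breaks <- as.Date(c("2013-01-01", "2013-06-05", "2013-08-25", "2014-01-01"))

probe <- as.Date(c("2013-01-01", "2013-06-04", "2013-06-05",
                   "2013-08-24", "2013-08-25", "2013-12-31"))

# cut() on Dates uses right = FALSE by default, i.e. intervals [start, end)
terms_found <- cut(probe, breaks = term_breaks,
                   labels = c("spring", "summer", "fall"))
data.frame(date = probe, term = terms_found)
```

The last break (2014-01-01) itself would fall outside every interval and become NA; that cannot happen here because the flights data only cover 2013.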

Distribution of flight numbers by day-of-week, for each term:

daily %>% 
  ggplot(aes(wday, n, colour = term)) +
    geom_boxplot()

It looks like there is substantial variation across the terms, so fitting a separate day-of-week effect for each term is reasonable.

mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)

daily %>% 
  gather_residuals(without_term = mod1, with_term = mod2) %>% 
  ggplot(aes(date, resid, colour = model)) +
    geom_line(alpha = 0.75)

This improves our model, but we can still spot the day of week effect.
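In an R formula, wday * term is shorthand for wday + term + wday:term, i.e. both main effects plus their interaction, which is what gives each term its own day-of-week pattern. The sketch below checks this with base R's terms() and model.matrix() on a hypothetical grid of all weekday-by-term combinations (the factor levels are typed in by hand here).

```r
f <- n ~ wday * term
labs <- attr(terms(f), "term.labels")
labs   # "wday", "term", "wday:term"

# Hypothetical grid: every combination of 7 weekdays and 3 terms
cells <- expand.grid(
  wday = factor(c("Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat")),
  term = factor(c("spring", "summer", "fall"))
)

# 1 intercept + 6 wday + 2 term + 12 interaction columns = 21,
# i.e. one coefficient per weekday-term cell
ncol(model.matrix(~ wday * term, data = cells))
```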

Overlay the predictions from the model on to the raw data:

grid <- daily %>% 
  data_grid(wday, term) %>% 
  add_predictions(mod2, "n")

ggplot(daily, aes(wday, n)) +
  geom_boxplot() + 
  geom_point(data = grid, colour = "red") + 
  facet_wrap(~ term)

There are many large outliers (especially in the fall), and they pull the least-squares fit around. Instead, use a model that is robust to the effect of outliers, MASS::rlm():

mod3 <- MASS::rlm(n ~ wday * term, data = daily)

daily %>% 
  gather_residuals(non_robust = mod2, robust = mod3) %>% 
  ggplot(aes(date, resid, colour = model)) +
    geom_line(alpha = 0.75)

Now the day of week effect is quite small. It’s now much easier to see the positive and negative outliers, and the long-term trend.
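To see what robustness buys, here is a sketch on made-up data with a single gross outlier (the numbers are invented for illustration). With its default settings, MASS::rlm() fits a Huber M-estimator by iteratively reweighted least squares: observations with large residuals get weights below 1, so one bad point cannot drag the fit the way it drags lm().

```r
set.seed(1)
x <- 1:10
y <- 1 + 2 * x + rnorm(10, sd = 0.5)   # made-up data around the line y = 1 + 2x
y[10] <- 100                           # one gross outlier at the end

fit_ols <- lm(y ~ x)
fit_rob <- MASS::rlm(y ~ x)   # Huber M-estimator (the rlm() default), fitted by IRLS

coef(fit_ols)["x"]      # slope dragged well away from 2
coef(fit_rob)["x"]      # slope stays close to 2
round(fit_rob$w, 2)     # final IRLS weights: the outlier gets the smallest
```

The same down-weighting is what reduces the influence of the holiday outliers in mod3.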

Time of year: an alternative approach

Instead of applying our domain knowledge (how the US school term affects travel) to improve the linear model, we could use a more flexible model and allow that to capture the pattern we’re interested in:

library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)

daily %>% 
  data_grid(wday, date = seq_range(date, n = 13)) %>% 
  add_predictions(mod) %>% 
  ggplot(aes(date, pred, colour = wday)) + 
    geom_line() +
    geom_point()

We see a strong pattern in the numbers of Saturday flights. This is reassuring, because we also saw that pattern in the raw data. It’s a good sign when you get the same signal from different approaches.
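The ns() term is what makes this model flexible: ns(date, 5) expands date into a natural cubic spline basis with 5 columns, and wday * ns(date, 5) lets each day of the week have its own smooth curve over the year. The sketch below counts the resulting columns for one year of daily dates. It converts the dates with as.numeric() (an assumption made to keep the example independent of how ns() treats Date objects) and uses weekdays(), whose labels depend on your locale; only the count of 7 levels matters here.

```r
library(splines)

# One (non-leap) year of daily dates, as in the flights data
dates <- seq(as.Date("2013-01-01"), as.Date("2013-12-31"), by = "day")

# Natural cubic spline basis with 5 degrees of freedom -> 5 columns
basis <- ns(as.numeric(dates), df = 5)
dim(basis)   # 365 5

# wday * ns(date, 5): a separate smooth curve for each day of the week.
# Columns: 1 intercept + 6 wday + 5 spline + 6 * 5 interaction = 42
wday <- factor(weekdays(dates))
X <- model.matrix(~ wday * ns(as.numeric(dates), df = 5))
ncol(X)      # 42
```

That is 42 coefficients estimated from 365 days, which is one reason to keep the spline's degrees of freedom small.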