library(tidyverse)
library(modelr)
options(na.action = na.warn)
library(nycflights13)
library(lubridate)
In Lecture 4 we saw a surprising relationship between the quality of diamonds and their price: low-quality diamonds (poor cuts, bad colours, and inferior clarity) have higher prices!
ggplot(diamonds, aes(cut, price)) + geom_boxplot()
ggplot(diamonds, aes(color, price)) + geom_boxplot()
ggplot(diamonds, aes(clarity, price)) + geom_boxplot()
The worst diamond color is J, and the worst clarity is I1.
In statistics, a confounder (also confounding variable, confounding factor or lurking variable) is a variable that influences both the dependent variable and the independent variable, causing a spurious association.
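To make the definition concrete, here is a minimal sketch on synthetic data (not the diamonds data). The variable names `x`, `y` and `z` and the coefficients are assumptions chosen for illustration: `z` drives both `x` and `y`, while `x` has no direct effect on `y`.

```r
# Synthetic illustration only: z is a confounder of x and y.
set.seed(42)
n <- 5000
z <- rnorm(n)            # the confounder
x <- z + rnorm(n)        # x depends on z
y <- 2 * z + rnorm(n)    # y depends on z only, not on x

naive    <- lm(y ~ x)      # ignores the confounder
adjusted <- lm(y ~ x + z)  # includes the confounder

coef(naive)["x"]     # clearly positive: a spurious association
coef(adjusted)["x"]  # close to zero once z is in the model
```

The association between `x` and `y` largely disappears once the model accounts for `z`, which is the same idea we apply to diamonds below.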
What would be the single most important factor for determining the price of the diamond?
ggplot(diamonds, aes(carat, price)) +
geom_hex(bins = 50)
Lower quality diamonds tend to be larger.
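One quick way to check this claim numerically is to summarise weight within each quality grade. This is a sketch using `cut` only; the summary column names are our own choice, and the same pattern of code would apply to `color` and `clarity`.

```r
library(dplyr)
library(ggplot2)  # provides the diamonds dataset

# Median weight (carat) within each cut grade
carat_by_cut <- diamonds %>%
  group_by(cut) %>%
  summarise(median_carat = median(carat), n = n())
carat_by_cut
```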
We can make it easier to see how the other attributes of a diamond affect its relative price by fitting a model to separate out the effect of carat.
Some tweaks to the diamonds dataset for clarity:
diamonds2 <- diamonds %>%
filter(carat <= 2.5) %>%
mutate(lprice = log2(price), lcarat = log2(carat))
ggplot(diamonds2, aes(lcarat, lprice)) +
geom_hex(bins = 50)
The log-transformation makes the pattern linear.
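Why would taking logs produce a straight line? One possible explanation is a power-law relationship: if price is roughly a * carat^b, then log2(price) is roughly log2(a) + b * log2(carat). The sketch below checks this on synthetic data; the constants 3 and 1.7 are made up for illustration and are not estimates from the diamonds data.

```r
# Synthetic power law with multiplicative noise (illustrative constants)
set.seed(1)
x <- runif(500, min = 0.2, max = 2.5)
y <- 3 * x^1.7 * 2^rnorm(500, sd = 0.1)

# On the log2-log2 scale the relationship is linear
fit <- lm(log2(y) ~ log2(x))
coef(fit)  # intercept near log2(3), slope near 1.7
```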
Remove the strong linear pattern:
mod_diamond <- lm(lprice ~ lcarat, data = diamonds2)
Visualise the fitted model after undoing the log transformation:
grid <- diamonds2 %>%
data_grid(carat = seq_range(carat, 20)) %>%
mutate(lcarat = log2(carat)) %>%
add_predictions(mod_diamond, "lprice") %>%
mutate(price = 2 ^ lprice)
ggplot(diamonds2, aes(carat, price)) +
geom_hex(bins = 50) +
geom_line(data = grid, colour = "red", size = 1)
Now look at the residuals to see if we’ve successfully removed the strong linear pattern:
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond, "lresid")
ggplot(diamonds2, aes(lcarat, lresid)) +
geom_hex(bins = 50)
Re-do our motivating plots using those residuals instead of price:
ggplot(diamonds2, aes(cut, lresid)) + geom_boxplot()
ggplot(diamonds2, aes(color, lresid)) + geom_boxplot()
ggplot(diamonds2, aes(clarity, lresid)) + geom_boxplot()
Now, as the quality of the diamond increases, so too does its relative price, as expected.
Interpreting the y axis: a residual of -1 indicates that lprice was 1 unit lower than a prediction based solely on the diamond’s weight.
Because we used log2(), a residual of -1 means the price is half the predicted price, and a residual of 1 means it is twice the predicted price.
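The conversion from a log2 residual back to a price ratio is just a power of two. A small sketch (the vector name `lresid` is only for illustration):

```r
# Residuals on the log2 scale and the implied ratio actual / predicted
lresid <- c(-1, 0, 1, 2)
ratio <- 2^lresid
ratio
# 0.5 1.0 2.0 4.0
```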
mod_diamond2 <- lm(lprice ~ lcarat + color + cut + clarity, data = diamonds2)
The model now includes four predictors (lcarat, color, cut and clarity). We can plot their effects individually in four plots:
grid <- diamonds2 %>%
data_grid(cut, .model = mod_diamond2) %>%
add_predictions(mod_diamond2)
grid
## # A tibble: 5 x 5
## cut lcarat color clarity pred
## <ord> <dbl> <chr> <chr> <dbl>
## 1 Fair -0.515 G VS2 11.2
## 2 Good -0.515 G VS2 11.3
## 3 Very Good -0.515 G VS2 11.4
## 4 Premium -0.515 G VS2 11.4
## 5 Ideal -0.515 G VS2 11.4
ggplot(grid, aes(cut, pred)) +
geom_point()
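In the grid above, lcarat, color and clarity are held at fixed values while cut varies. Our understanding is that `data_grid()` with `.model` fills in predictors that were not supplied using "typical" values, computed by `modelr::typical()`: the median for numeric vectors and the most frequent value for categorical ones. A small sketch on toy vectors:

```r
library(modelr)

typ_num <- typical(c(1, 2, 10))       # numeric: the median
typ_chr <- typical(c("a", "b", "b"))  # character: the most frequent value
typ_num
typ_chr
```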
grid <- diamonds2 %>%
data_grid(color, .model = mod_diamond2) %>%
add_predictions(mod_diamond2)
grid
## # A tibble: 7 x 5
## color lcarat cut clarity pred
## <ord> <dbl> <chr> <chr> <dbl>
## 1 D -0.515 Premium VS2 11.6
## 2 E -0.515 Premium VS2 11.6
## 3 F -0.515 Premium VS2 11.5
## 4 G -0.515 Premium VS2 11.4
## 5 H -0.515 Premium VS2 11.3
## 6 I -0.515 Premium VS2 11.1
## 7 J -0.515 Premium VS2 10.9
ggplot(grid, aes(color, pred)) +
geom_point()
diamonds2 <- diamonds2 %>%
add_residuals(mod_diamond2, "lresid2")
ggplot(diamonds2, aes(lcarat, lresid2)) +
geom_hex(bins = 50)
This plot indicates that there are some diamonds with quite large residuals: remember that a residual of 2 indicates that the diamond is 4 times the price that we expected.
It’s often useful to look at unusual values individually:
diamonds2 %>%
filter(abs(lresid2) > 1) %>%
add_predictions(mod_diamond2) %>%
mutate(pred = round(2 ^ pred)) %>%
select(price, pred, carat:table, x:z) %>%
arrange(price) # sort by price in ascending order
## # A tibble: 16 x 11
## price pred carat cut color clarity depth table x y z
## <int> <dbl> <dbl> <ord> <ord> <ord> <dbl> <dbl> <dbl> <dbl> <dbl>
## 1 1013 264 0.25 Fair F SI2 54.4 64 4.3 4.23 2.32
## 2 1186 284 0.25 Premium G SI2 59 60 5.33 5.28 3.12
## 3 1186 284 0.25 Premium G SI2 58.8 60 5.33 5.28 3.12
## 4 1262 2644 1.03 Fair E I1 78.2 54 5.72 5.59 4.42
## 5 1415 639 0.35 Fair G VS2 65.9 54 5.57 5.53 3.66
## 6 1415 639 0.35 Fair G VS2 65.9 54 5.57 5.53 3.66
## 7 1715 576 0.32 Fair F VS2 59.6 60 4.42 4.34 2.61
## 8 1776 412 0.290 Fair F SI1 55.8 60 4.48 4.41 2.48
## 9 2160 314 0.34 Fair F I1 55.8 62 4.72 4.6 2.6
## 10 2366 774 0.3 Very Good D VVS2 60.6 58 4.33 4.35 2.63
## 11 3360 1373 0.51 Premium F SI1 62.7 62 5.09 4.96 3.15
## 12 3807 1540 0.61 Good F SI2 62.5 65 5.36 5.29 3.33
## 13 3920 1705 0.51 Fair F VVS2 65.4 60 4.98 4.9 3.23
## 14 4368 1705 0.51 Fair F VVS2 60.7 66 5.21 5.11 3.13
## 15 10011 4048 1.01 Fair D SI2 64.6 58 6.25 6.2 4.02
## 16 10470 23622 2.46 Premium E SI2 59.7 59 8.82 8.76 5.25
Nothing really jumps out here, but if there are mistakes in the data, this could be an opportunity to buy diamonds that have been incorrectly priced low.
Count the number of flights per day:
daily <- flights %>%
mutate(date = make_date(year, month, day)) %>%
group_by(date) %>%
summarise(n = n())
daily
## # A tibble: 365 x 2
## date n
## <date> <int>
## 1 2013-01-01 842
## 2 2013-01-02 943
## 3 2013-01-03 914
## 4 2013-01-04 915
## 5 2013-01-05 720
## 6 2013-01-06 832
## 7 2013-01-07 933
## 8 2013-01-08 899
## 9 2013-01-09 902
## 10 2013-01-10 932
## # ... with 355 more rows
ggplot(daily, aes(date, n)) +
geom_line()
Distribution of flight numbers by day-of-week:
daily <- daily %>%
mutate(wday = wday(date, label = TRUE))
ggplot(daily, aes(wday, n)) +
geom_boxplot()
There are fewer flights on weekends, especially on Saturdays (why?).
Fit a model to remove this strong pattern:
mod <- lm(n ~ wday, data = daily)
grid <- daily %>%
data_grid(wday) %>%
add_predictions(mod, "n")
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red", size = 4)
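With a single categorical predictor, a linear model simply predicts the mean of each group, which is why the red points sit at the average for each day of the week. A minimal sketch on a toy data frame (the names `toy`, `g` and `y` are made up for illustration):

```r
# Toy data: two groups with different means
toy <- data.frame(
  g = factor(c("a", "a", "b", "b", "b")),
  y = c(1, 3, 10, 20, 30)
)

fit <- lm(y ~ g, data = toy)

group_means <- tapply(toy$y, toy$g, mean)
preds <- predict(fit, newdata = data.frame(g = factor(c("a", "b"))))

group_means  # a = 2, b = 20
preds        # the same values: 2 and 20
```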
Next we compute and visualise the residuals:
daily <- daily %>%
add_residuals(mod)
daily %>%
ggplot(aes(date, resid)) +
geom_ref_line(h = 0) +
geom_line()
Now we are seeing the deviation from the expected number of flights, given the day of week. Compare this with the plot of raw daily counts above.
Some of the subtler patterns still remain:
There is still a strong regular pattern starting from June. To see why, draw a plot with one line for each day of the week:
ggplot(daily, aes(date, resid, colour = wday)) +
geom_ref_line(h = 0) +
geom_line()
Our model fails to accurately predict the number of flights on Saturday: during summer there are more flights than we expect, and during Fall there are fewer.
There are some days with far fewer flights than expected:
daily %>%
filter(resid < -100)
## # A tibble: 11 x 4
## date n wday resid
## <date> <int> <ord> <dbl>
## 1 2013-01-01 842 Tue -109.
## 2 2013-01-20 786 Sun -105.
## 3 2013-05-26 729 Sun -162.
## 4 2013-07-04 737 Thu -229.
## 5 2013-07-05 822 Fri -145.
## 6 2013-09-01 718 Sun -173.
## 7 2013-11-28 634 Thu -332.
## 8 2013-11-29 661 Fri -306.
## 9 2013-12-24 761 Tue -190.
## 10 2013-12-25 719 Wed -244.
## 11 2013-12-31 776 Tue -175.
If you’re familiar with American public holidays, you might spot New Year’s Day, July 4th, Thanksgiving and Christmas. You’ll work on those in a homework assignment.
There seems to be some smoother long term trend over the course of a year.
daily %>%
ggplot(aes(date, resid)) +
geom_ref_line(h = 0) +
geom_line(colour = "grey50") +
geom_smooth(se = FALSE, span = 0.20)
## `geom_smooth()` using method = 'loess' and formula 'y ~ x'
There are fewer flights in January (and December), and more in summer (May-Sep). We can’t do much with this pattern quantitatively, because we only have a single year of data. But we can use our domain knowledge to brainstorm potential explanations.
Go back to the raw numbers, focussing on Saturdays:
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n)) +
geom_point() +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
This pattern is probably caused by summer holidays: many people go on holiday in the summer, and people don’t mind travelling on Saturdays for vacation. Looking at this plot, we might guess that summer holidays are from early June to late August. That seems to line up fairly well with New York state’s school terms: summer break in 2013 was Jun 26–Sep 9.
(Then why are there more Saturday flights in the Spring than the Fall?)
Create a “term” variable that roughly captures the three school terms, and check with a plot:
term <- function(date) {
cut(date,
breaks = ymd(20130101, 20130605, 20130825, 20140101),
labels = c("spring", "summer", "fall")
)
}
daily <- daily %>%
mutate(term = term(date))
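To see exactly where the boundaries fall, here is a small sketch that applies the same breaks to a few hand-picked dates, using base R's `as.Date()` instead of `ymd()`. For dates, `cut()` uses left-closed intervals by default, so a date equal to a break belongs to the term that starts on it.

```r
breaks <- as.Date(c("2013-01-01", "2013-06-05", "2013-08-25", "2014-01-01"))
dates  <- as.Date(c("2013-01-01", "2013-06-04", "2013-06-05",
                    "2013-08-24", "2013-08-25", "2013-12-31"))

# Each date is assigned to the interval [break_i, break_i+1)
term_of <- cut(dates, breaks = breaks, labels = c("spring", "summer", "fall"))
data.frame(date = dates, term = term_of)
# terms, in order: spring, spring, summer, summer, fall, fall
```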
daily %>%
filter(wday == "Sat") %>%
ggplot(aes(date, n, colour = term)) +
geom_point(alpha = 1/3) +
geom_line() +
scale_x_date(NULL, date_breaks = "1 month", date_labels = "%b")
Distribution of flight numbers by day-of-week, for each term:
daily %>%
ggplot(aes(wday, n, colour = term)) +
geom_boxplot()
It looks like there is significant variation across the terms, so fitting a separate day of week effect for each term is reasonable.
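In R's formula syntax, a separate day-of-week effect for each term is an interaction, written `wday * term`. The `*` operator is shorthand for both main effects plus their interaction, which we can confirm by expanding the formula:

```r
# Expand the formula to see which terms the model will contain
f <- n ~ wday * term
term_labels <- attr(terms(f), "term.labels")
term_labels
# "wday" "term" "wday:term"
```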
mod1 <- lm(n ~ wday, data = daily)
mod2 <- lm(n ~ wday * term, data = daily)
daily %>%
gather_residuals(without_term = mod1, with_term = mod2) %>%
ggplot(aes(date, resid, colour = model)) +
geom_line(alpha = 0.75)
This improves our model, but we can still spot the day of week effect.
Overlay the predictions from the model on to the raw data:
grid <- daily %>%
data_grid(wday, term) %>%
add_predictions(mod2, "n")
ggplot(daily, aes(wday, n)) +
geom_boxplot() +
geom_point(data = grid, colour = "red") +
facet_wrap(~ term)
We have a lot of big outliers (especially in the fall); use a model that is robust to the effect of outliers:
mod3 <- MASS::rlm(n ~ wday * term, data = daily)
daily %>%
gather_residuals(non_robust = mod2, robust = mod3) %>%
ggplot(aes(date, resid, colour = model)) +
geom_line(alpha = 0.75)
Now the day of week effect is quite small. It’s now much easier to see the positive and negative outliers, and the long-term trend.
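To see why a robust fit helps, here is a small sketch on synthetic data with one gross outlier. The true slope of 2 and the outlier value are made up for illustration; `MASS::rlm()` with its default settings down-weights observations with large residuals, so we would expect its slope to stay much closer to the true value than the least-squares slope.

```r
# Synthetic line with slope 2 and a single large outlier
set.seed(1)
x <- 1:20
y <- 1 + 2 * x + rnorm(20)
y[20] <- 200

fit_lm  <- lm(y ~ x)         # ordinary least squares
fit_rlm <- MASS::rlm(y ~ x)  # robust fit

coef(fit_lm)["x"]   # pulled well away from 2 by the outlier
coef(fit_rlm)["x"]  # expected to be closer to 2
```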
Instead of applying our domain knowledge (how the US school term affects travel) to improve the linear model, we could use a more flexible model and allow that to capture the pattern we’re interested in:
library(splines)
mod <- MASS::rlm(n ~ wday * ns(date, 5), data = daily)
daily %>%
data_grid(wday, date = seq_range(date, n = 13)) %>%
add_predictions(mod) %>%
ggplot(aes(date, pred, colour = wday)) +
geom_line() +
geom_point()
We see a strong pattern in the numbers of Saturday flights. This is reassuring, because we also saw that pattern in the raw data. It’s a good sign when you get the same signal from different approaches.
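For reference, `ns(date, 5)` replaces the single date variable with a natural spline basis with five degrees of freedom, so `wday * ns(date, 5)` lets each day of the week follow its own smooth curve over the year. A minimal sketch of the basis on a toy numeric vector (not the flight dates):

```r
library(splines)

x <- seq(0, 1, length.out = 100)
basis <- ns(x, df = 5)  # one column per degree of freedom
dim(basis)              # 100 rows, 5 columns
```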