Regression function Figure

Data generation

# Load packages
library(ggplot2)
library(FNN)  # Used for k-nearest neighbours fit
library(ISLR)

## Note: if you get Error messages saying things like:
## Error in library(ggplot2): there is no package called 'ggplot2'
## you need to run the command:
## install.packages("ggplot2")
##
## Similarly for other packages.

# Colour-blind-friendly palette; the plots below pick colours by position
# (e.g. cbPalette[1] for the data points).
cbPalette <- c(
  "#999999", "#E69F00", "#56B4E9", "#009E73",
  "#F0E442", "#0072B2", "#D55E00", "#CC79A7"
)

# Set seed for the random number generator so that the simulated data
# (and hence the figures built from it) are reproducible
set.seed(4490)


# True regression function for the first example:
#   f(x) = 1 + 0.1 * (x - 2.5)^2
# Vectorised: accepts a numeric vector x and returns a vector of equal length.
f.reg <- function(x) {
  centred <- x - 2.5
  1 + 0.1 * centred^2
}

# Generate data
n <- 2000 # num observations
# Covariate: n draws from a Normal distribution with mean 4 and sd 1.75
x <- rnorm(n, mean = 4, sd = 1.75)

# Response: true regression function plus Normal(0, sd = 0.25) noise
y <- f.reg(x) + rnorm(n, sd = 0.25)

# Combine into data frame
df <- data.frame(x = x, y = y)

# Reference point on the true regression curve, (x.val, f.reg(x.val));
# it is highlighted in the first plot
x.val <- 5
y.val <- f.reg(x.val)

The code above produced 2000 randomly generated realizations from the underlying model specified by the f.reg function. More precisely, the code starts by generating 2000 x-values from the $\mathrm{Normal}(4, \sigma^2 = 1.75^2)$ distribution, and then forms y-values according to:

$$y_i = 1 + 0.1\,(x_i - 2.5)^2 + \epsilon_i$$

where the errors are distributed as $\epsilon_i \sim \mathrm{Normal}(0, \sigma^2 = 0.25^2)$.

Resulting plots

# Scatter plot of the simulated data with the true regression function
# overlaid, plus a marker and a vertical guide line at x = x.val.
# Built with ggplot() because qplot() is deprecated in current ggplot2.
ggplot(df, aes(x = x, y = y)) +
  geom_point(alpha = 0.5, colour = cbPalette[1]) +
  stat_function(fun = f.reg, colour = cbPalette[2], lwd = 1.25) +
  # annotate() draws the marker exactly once; mapping the constants through
  # aes() would recycle them and redraw the same point for every data row.
  annotate("point", x = x.val, y = y.val, size = 4, colour = cbPalette[4]) +
  # Use x.val rather than a hard-coded 5 so the line always tracks the marker.
  geom_vline(xintercept = x.val, colour = cbPalette[4], lwd = 1)

Here’s a figure showing the true regression function, the observed data, and the linear regression (blue) and 50-nearest-neighbours (green) fits to the data.

# 50-NN regression model (k = 50, matching the figure description).
# No test set is passed, so per the FNN documentation knn.fit$pred holds
# leave-one-out cross-validated predictions at the training x values.
knn.fit <- knn.reg(train = x, y = y, k = 50)

# Observed data, true regression function, linear regression fit and 50-NN fit.
# The predictions are placed in the plot data so that aes() refers to a column
# instead of extracting from knn.fit with `$`.
ggplot(data.frame(x = x, y = y, knn.pred = knn.fit$pred), aes(x = x, y = y)) +
  geom_point(alpha = 0.5, colour = cbPalette[1]) +
  stat_function(fun = f.reg, colour = cbPalette[2], lwd = 1.5) +
  # Linear regression fit; the formula is spelled out to avoid the
  # "using formula" message emitted when it is left implicit.
  stat_smooth(method = "lm", formula = y ~ x, se = FALSE, lwd = 1) +
  # 50-NN fit
  geom_line(aes(y = knn.pred), colour = cbPalette[4], lwd = 0.75)

Best linear approximation

Data generation

# Reseed the RNG so the data simulated for this example are reproducible
set.seed(14316)

# True regression function (non-linear):
#   f(x) = (1 + 0.2 * cos(1.5 * x)) * (3 * x + 1)
# i.e. the line 3x + 1 scaled by a cosine factor between 0.8 and 1.2.
# Vectorised over x.
true.reg <- function(x) {
  wobble <- 1 + 0.2 * cos(1.5 * x)
  trend <- 3 * x + 1
  wobble * trend
}

# Sample sizes for this example.
# NOTE: n is reassigned here; it was 2000 in the first example.
n <- 5000 # num observations for getting 'best linear approximation'
n.sub <- 200  # num observations for getting linear regression fit

We generate two sets of data for this example. First, we generate 5000 points for the purpose of figuring out the “best linear predictor” $f_L(x)$. We could derive $f_L(x)$ analytically (i.e., do a bunch of math to figure out what it should be), but here we’re cheating a little and just using a very large sample size to figure out what $f_L(x)$ should be.

To get a linear regression predictor $\hat{f}(x)$, we generate just 200 points and treat these as our observations.

# Generate data for figuring out f_L
# (x and y are reassigned here, replacing the first example's data)
x <- rnorm(n, mean = 3, sd = 1)
y <- true.reg(x) + rnorm(n, sd = 5)

# Generate data to feed into linear regression:
# a separate, smaller draw (n.sub points) from the same model
x.sub <- rnorm(n.sub, mean = 3, sd = 1)
y.sub <- true.reg(x.sub) + rnorm(n.sub, sd = 5)

Linear regression predictor plot

# Best linear approximation vs. linear regression fit.
# The plot's default data is the large sample (x, y); layers for the small
# observed sample get their own data frame, so no layer relies on aes()
# vectors of a different length than its data.
# Built with ggplot() because qplot() is deprecated in current ggplot2.
sub.df <- data.frame(x = x.sub, y = y.sub)

ggplot(data.frame(x = x, y = y), aes(x = x, y = y)) +
  # Large sample drawn fully transparent: it stays invisible but still
  # sets the axis ranges.
  geom_point(alpha = 0, colour = cbPalette[1]) +
  # Observed (small) sample
  geom_point(data = sub.df, alpha = 0.7, colour = cbPalette[1]) +
  # True regression function
  stat_function(fun = true.reg, colour = cbPalette[2], lwd = 1.25) +
  # Linear regression fit based on the observed sample
  stat_smooth(data = sub.df, method = "lm", formula = y ~ x,
              se = FALSE, lwd = 1) +
  # "Best linear predictor": linear fit to the large sample
  stat_smooth(method = "lm", formula = y ~ x, se = FALSE,
              colour = cbPalette[7], lty = 2, lwd = 1) +
  theme_bw()

Shown: True regression function (solid orange), Best linear predictor (dashed burnt orange), Linear regression fit based on observed points (solid blue)

Polynomial regression and step functions

# Scatter plot of wage against age (Wage data from ISLR) overlaid with a
# least-squares fit of the supplied model formula.
#
# fit.formula: formula in terms of x and y, as expected by stat_smooth(),
#              e.g. y ~ poly(x, 2).
# Returns the ggplot object; at top level it is drawn by auto-printing.
#
# Replaces eight copy-pasted qplot() calls that differed only in the formula
# (qplot() is deprecated in current ggplot2).
wage.fit.plot <- function(fit.formula) {
  ggplot(Wage, aes(x = age, y = wage)) +
    geom_point(colour = cbPalette[1], alpha = 0.75) +
    stat_smooth(method = "lm", formula = fit.formula, lwd = 1.25) +
    labs(x = "Age", y = "Wage") +
    theme_bw()
}

#####
# Polynomial regression: degrees 1, 2, 3, 4 and 10
#####

wage.fit.plot(y ~ poly(x, 1))

wage.fit.plot(y ~ poly(x, 2))

wage.fit.plot(y ~ poly(x, 3))

wage.fit.plot(y ~ poly(x, 4))

wage.fit.plot(y ~ poly(x, 10))

#####
# Step functions: age is cut into intervals, one fitted level per interval
#####

# Two intervals, split at 65
wage.fit.plot(y ~ cut(x, breaks = c(-Inf, 65, Inf)))

# Three intervals, split at 35 and 65
wage.fit.plot(y ~ cut(x, breaks = c(-Inf, 35, 65, Inf)))

# Four intervals, split at 25, 35 and 65
wage.fit.plot(y ~ cut(x, breaks = c(-Inf, 25, 35, 65, Inf)))