library(ggplot2)
library(ISLR)
library(MASS)
library(partykit)
library(caret)
library(rpart)
library(randomForest)
library(pROC)
# Read in the marketing data
marketing <- read.table("http://www.andrew.cmu.edu/user/achoulde/94842/data/bank-full.csv",
header = TRUE, sep = ";")
set.seed(531)
# Upsample the data to artificially overcome sample imbalance
marketing.more.idx <- sample(which(marketing$y == "yes"), 15000, replace = TRUE)
marketing.upsample <- rbind(marketing,
marketing[marketing.more.idx, ])
# Randomly select 20% of the data to be held out for model validation
test.indexes <- sample(1:nrow(marketing.upsample),
round(0.2 * nrow(marketing.upsample)))
train.indexes <- setdiff(1:nrow(marketing.upsample), test.indexes)
# Just pull the covariates available to marketers (cols 1:8) and the outcome (col 17)
marketing.train <- marketing.upsample[train.indexes, c(1:8, 17)]
marketing.test <- marketing.upsample[test.indexes, c(1:8, 17)]
In this problem we’ll assume that we have a binary classification problem where our outcome variable \(Y \in \{0, 1\}\). Your main task is to construct a function that calculates various kinds of classifier performance metrics.
set.seed(826)
score.fake <- runif(200)
y.fake <- as.numeric(runif(200) <= score.fake)
| Argument | Description |
|---|---|
| `score` | length-n vector giving a score for every observation |
| `y` | true observed class label for each observation |
| `cutoff` | score cutoff: classify \(\hat y = 1\) if `score >= cutoff` |
| `type` | which performance metric(s) to return. `type = "all"` calculates all |
Your output will be a list containing the following elements:

| Element | Description |
|---|---|
| `conf.mat` | the confusion matrix for the classifier |
| `perf` | a data frame containing all of the desired metrics |
Example output:
# Cutoff 0.6
classMetrics(score.fake, y.fake, cutoff = 0.6, type = "all")
$conf.mat
observed
predicted 0 1
0 82 31
1 15 72
$perf
value
accuracy 0.7700000
sensitivity 0.6990291
specificity 0.8453608
ppv 0.8275862
npv 0.7256637
precision 0.8275862
recall 0.6990291
# Cutoff 0.2
classMetrics(score.fake, y.fake, cutoff = 0.2, type = "all")
$conf.mat
observed
predicted 0 1
0 36 3
1 61 100
$perf
value
accuracy 0.6800000
sensitivity 0.9708738
specificity 0.3711340
ppv 0.6211180
npv 0.9230769
precision 0.6211180
recall 0.9708738
# Precision and recall only
classMetrics(score.fake, y.fake, cutoff = 0.2, type = c("precision", "recall"))
$conf.mat
observed
predicted 0 1
0 36 3
1 61 100
$perf
value
precision 0.6211180
recall 0.9708738
classMetrics <- function(score, y, cutoff,
                         type = c("all", "accuracy", "sensitivity",
                                  "specificity", "ppv", "npv",
                                  "precision", "recall")) {
  # This command throws an error if the user specifies a "type" that
  # isn't supported by this function
  type <- match.arg(type, several.ok = TRUE)
  # Edit me
}
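One possible way to fill in the body (a sketch, assuming `y` is coded 0/1 and that `score >= cutoff` is classified as 1; `classMetrics.sketch` is an illustrative name, not the required solution): build the confusion matrix with `table`, then derive each metric from its four cells.

# Sketch only: tabulate predictions against observed labels, then compute
# each metric from the confusion-matrix cells
classMetrics.sketch <- function(score, y, cutoff,
                                type = c("all", "accuracy", "sensitivity",
                                         "specificity", "ppv", "npv",
                                         "precision", "recall")) {
  type <- match.arg(type, several.ok = TRUE)
  predicted <- as.numeric(score >= cutoff)
  conf.mat <- table(predicted = factor(predicted, levels = c(0, 1)),
                    observed  = factor(y, levels = c(0, 1)))
  tn <- conf.mat[1, 1]; fn <- conf.mat[1, 2]
  fp <- conf.mat[2, 1]; tp <- conf.mat[2, 2]
  perf <- data.frame(value = c(accuracy    = (tp + tn) / sum(conf.mat),
                               sensitivity = tp / (tp + fn),
                               specificity = tn / (tn + fp),
                               ppv         = tp / (tp + fp),
                               npv         = tn / (tn + fn),
                               precision   = tp / (tp + fp),
                               recall      = tp / (tp + fn)))
  if (!("all" %in% type)) {
    perf <- perf[type, , drop = FALSE]
  }
  list(conf.mat = conf.mat, perf = perf)
}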
If `xvar = NULL`, the x-axis variable should be taken to be `score`, and should range from the smallest to the largest value of `score`. If `flip.x = TRUE`, you should plot `1 - xvar_metric` on the x-axis. E.g., if `xvar = "specificity"` and `flip.x = TRUE`, your plot should have 1 - specificity as the x-axis variable.

Example output:
plotClassMetrics <- function(score, y, xvar = NULL,
                             yvar = c("accuracy", "sensitivity", "specificity",
                                      "ppv", "npv", "precision", "recall"),
                             flip.x = FALSE) {
  yvar <- match.arg(yvar)
  # Edit me
}
# ROC curve
test <- plotClassMetrics(score.fake, y.fake, xvar = "specificity", yvar = "sensitivity",
flip.x = TRUE)
plotClassMetrics(score.fake, y.fake, yvar = "precision")
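Here is one way the plotting function might be sketched (it reuses the hypothetical `classMetrics.sketch` above and plots with `ggplot2`; again, an illustration rather than the required solution):

# Sketch only: evaluate the requested metric(s) at every observed score value
# and plot yvar against either the score or another metric
plotClassMetrics.sketch <- function(score, y, xvar = NULL,
                                    yvar = c("accuracy", "sensitivity",
                                             "specificity", "ppv", "npv",
                                             "precision", "recall"),
                                    flip.x = FALSE) {
  yvar <- match.arg(yvar)
  cutoffs <- sort(unique(score))
  # Helper: compute one metric at every candidate cutoff
  getMetric <- function(metric) {
    sapply(cutoffs, function(ct) {
      classMetrics.sketch(score, y, cutoff = ct, type = metric)$perf[metric, "value"]
    })
  }
  y.vals <- getMetric(yvar)
  if (is.null(xvar)) {
    x.vals <- cutoffs
    x.lab <- "score"
  } else {
    x.vals <- getMetric(xvar)
    x.lab <- xvar
    if (flip.x) {
      x.vals <- 1 - x.vals
      x.lab <- paste("1 -", xvar)
    }
  }
  ggplot(data.frame(x = x.vals, y = y.vals), aes(x = x, y = y)) +
    geom_line() +
    labs(x = x.lab, y = yvar)
}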
We'll need to construct `rpart` objects instead of `tree` objects in order to use the more advanced plotting routine from the `partykit` library. The syntax for `rpart` is similar to that of `tree`, and was demonstrated in the Lab for week 4. For additional details, you may refer to the following link.
We will be using the `marketing` data, which has been split into `marketing.train` and `marketing.test` in the preamble of this document. All model fitting should be done on `marketing.train`. The outcome variable in the data set is `y`, denoting whether the customer opened up a CD or not.
This data comes from a Portuguese banking institution that ran a marketing campaign to try to get clients to subscribe to a "term deposit" (a CD). A CD is an account that you can put money into that guarantees a fixed interest rate over a certain period of time (e.g., 2 years). The catch is that if you try to withdraw your money before the term ends, you will typically incur heavy penalties or "early withdrawal fees".
Suppose that you're hired as a decision support analyst at this bank, and your first job is to use the data to figure out who the marketing team should contact for their next CD marketing campaign. That is, they pull up a new spreadsheet containing the contact information, age, job, marital status, education level, default history, mortgage status, and personal loan status for tens of thousands of clients, and they want you to tell them who they should contact.
Use the `rpart()` function to fit a decision tree to the training data. Call this tree `marketing.tree`. The syntax is exactly the same as for the `tree` function you saw in Lab 4. Use the `plot` and `text` functions to visualize the tree. Show a text print-out of the tree. Which variables get used in fitting the tree?

# Edit me
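A sketch of what this step could look like (assuming the default `rpart` settings are acceptable here):

# Sketch only: fit the tree on the training data and inspect it
marketing.tree <- rpart(y ~ ., data = marketing.train)
plot(marketing.tree)
text(marketing.tree)
print(marketing.tree)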
The `as.party` command converts the `rpart` tree you fit in part (a) to a `party` object that has a much better plot function. Run `plot` on the object created below. Also run the `print` function. `"yes"` or `"no"`?

# marketing.party <- as.party(marketing.tree)
# Edit me
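A sketch of this step, following the commented code above:

# Sketch only: convert to a party object, then plot and print it
marketing.party <- as.party(marketing.tree)
plot(marketing.party)
print(marketing.party)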
Fit a tree `marketing.full` to the training data with complexity parameter `cp = 0.002`, while ensuring that no single node contains fewer than `minsplit = 100` observations. Run the `plotcp` command on this tree to get a plot of the cross-validated error. Also look at the `cptable` attribute of `marketing.full`. Observe that all of the errors are reported relative to that of the 1-node "tree".

# marketing.full <- rpart(y ~ ., data = marketing.train,
#                         control = rpart.control(minsplit = 100, cp = 0.002))
# Edit me
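A sketch of this step, mirroring the commented code above:

# Sketch only: grow a large tree, then examine the cross-validated error
marketing.full <- rpart(y ~ ., data = marketing.train,
                        control = rpart.control(minsplit = 100, cp = 0.002))
plotcp(marketing.full)
marketing.full$cptable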
Look at the cross-validated error for the `cp` values shown. Apply the 1-SE rule to determine which value of `cp` to use for pruning. Print this value of `cp`.

# Edit me
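One way to apply the 1-SE rule programmatically (a sketch; it assumes `marketing.full` from the previous part): find the smallest cross-validated error, add one standard error, and take the largest `cp` (i.e., the simplest tree) whose error falls below that threshold.

# Sketch only: 1-SE rule applied to the cptable
cp.table <- marketing.full$cptable
err.min  <- min(cp.table[, "xerror"])
se.min   <- cp.table[which.min(cp.table[, "xerror"]), "xstd"]
cp.1se   <- cp.table[which(cp.table[, "xerror"] <= err.min + se.min)[1], "CP"]
cp.1se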
Use the `prune` command (`prune(rpart.fit, cp = )`) to prune `marketing.full` to the level of complexity you settled on in part (e). Call your pruned tree `marketing.pruned`. Display a text print-out of your tree.

# Edit me
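A sketch of the pruning step (it assumes the hypothetical `cp.1se` value computed above):

# Sketch only: prune to the selected complexity and print the result
marketing.pruned <- prune(marketing.full, cp = cp.1se)
print(marketing.pruned)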
The questions below all refer to `marketing.pruned`.
Convert the `marketing.pruned` tree into a `party` object. Plot these results, supplying the argument `gp = gpar(fontsize = 10)` to make the text size more easily legible. Notice the use of `gpar` to set the `fontsize` for the plot.

# Uncomment the code below
# marketing.pruned.party <- as.party(marketing.pruned)
# Edit me
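A sketch of this step, following the commented code above:

# Sketch only: convert the pruned tree and plot it with smaller text
marketing.pruned.party <- as.party(marketing.pruned)
plot(marketing.pruned.party, gp = gpar(fontsize = 10))
print(marketing.pruned.party)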
Use the `predict` function on your pruned tree to get estimated probabilities of opening a CD for everyone in `marketing.test`. Use your `classMetrics` function to get classification metrics at probability `cutoff` values of 0.25, 0.4, and 0.5. Use your `plotClassMetrics` command to construct an ROC curve.

# Edit me
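A sketch of this step (it reuses the hypothetical `classMetrics.sketch` and `plotClassMetrics.sketch` functions above; `type = "prob"` asks `predict.rpart` for class probabilities):

# Sketch only: predicted probabilities of "yes" on the test set
tree.probs <- predict(marketing.pruned, newdata = marketing.test, type = "prob")[, "yes"]
y.test <- as.numeric(marketing.test$y == "yes")
classMetrics.sketch(tree.probs, y.test, cutoff = 0.25, type = "all")
classMetrics.sketch(tree.probs, y.test, cutoff = 0.4,  type = "all")
classMetrics.sketch(tree.probs, y.test, cutoff = 0.5,  type = "all")
# ROC curve: sensitivity against 1 - specificity
plotClassMetrics.sketch(tree.probs, y.test, xvar = "specificity",
                        yvar = "sensitivity", flip.x = TRUE)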
(Use `marketing.test` for this calculation.)

# Edit me
Use the `randomForest` command to fit a random forest to `marketing.train`. Call your fit `marketing.rf`. Show a print-out of your random forest fit. This print-out contains a confusion matrix. Are the predicted classes given as the rows or columns of this table?

# Edit me
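A sketch of the random forest fit (it assumes the outcome `y` is a factor; with recent versions of R you may need `marketing.train$y <- as.factor(marketing.train$y)` first so that `randomForest` performs classification):

# Sketch only: fit and print the random forest
marketing.rf <- randomForest(y ~ ., data = marketing.train)
print(marketing.rf)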
# Edit me
Use the `predict` command to obtain probability estimates on the test data. Use your `classMetrics` function to calculate performance metrics at `cutoff = 0.3`. Compare the metrics to those of the pruned tree `marketing.pruned` at the same `cutoff`.

# Edit me
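A sketch of this comparison (again using the hypothetical `classMetrics.sketch` and the `tree.probs` / `y.test` objects defined above):

# Sketch only: random forest probabilities of "yes" on the test set
rf.probs <- predict(marketing.rf, newdata = marketing.test, type = "prob")[, "yes"]
classMetrics.sketch(rf.probs, y.test, cutoff = 0.3, type = "all")
# Pruned tree at the same cutoff, for comparison
classMetrics.sketch(tree.probs, y.test, cutoff = 0.3, type = "all")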
Use the `roc` function from the `pROC` package to get an ROC curve for the random forest. Overlay the ROC curve for the pruned tree (use `steelblue` as the colour). Calculate the AUC for both methods. Do we do better with random forests than with a single tree? Are most of the gains at high or low values of specificity? That is, is the random forest performing better in the regime we actually care about?

# Edit me
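A sketch of the ROC comparison with `pROC` (it assumes the `rf.probs` and `tree.probs` vectors defined above):

# Sketch only: ROC curves and AUC for both classifiers
rf.roc   <- roc(marketing.test$y, rf.probs)
tree.roc <- roc(marketing.test$y, tree.probs)
plot(rf.roc)
plot(tree.roc, col = "steelblue", add = TRUE)
auc(rf.roc)
auc(tree.roc)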