---
title: "Introduction to machine learning with tidymodels"
subtitle: Predicting the age of bats
author: Dr Jamie Soul
format: html
editor: visual
title-block-banner: true
toc: true
self-contained: true
---
## Let's look at a genomics example!
Let's try to predict the age of bats from their skin DNA methylation data. The data are taken from NCBI GEO accession GSE164127.
## Loading the metapackage
```{r}
library(tidymodels)
```
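tidymodels attaches several packages (e.g. parsnip, recipes, rsample, tune). If other loaded packages mask these functions, `tidymodels_prefer()` resolves the conflicts in favour of tidymodels:
```{r}
#optional: resolve any function-name conflicts in favour of tidymodels
tidymodels_prefer()
```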
## Prepare the data
The GEOquery library allows us to download the normalised methylation beta values from NCBI GEO.
```{r}
#| message: false
library(GEOquery)
library(tidyverse)
library(skimr)
#retrieve the dataset - getGEO() always returns a list with one element per platform, even when there is only one platform
geo <- getGEO("GSE164127")[[1]]
```
Genomics datasets for machine learning tend to have many variables/features (e.g. CpGs, genes, proteins) and relatively few observations.
```{r}
#we have many CpGs (rows) but relatively few samples (columns)
dim(exprs(geo))
```
Beta values represent the proportion of measured beads with that site methylated, ranging from 0 (completely unmethylated) to 1 (completely methylated). Note the data are pre-normalised for us. Best practice would be to pre-process the train and test data completely independently, i.e. not normalised together at all.
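As a quick sanity check on the pre-normalised data, we can confirm the beta values fall in the expected 0-1 range:
```{r}
#sanity check: beta values should lie between 0 and 1
range(exprs(geo), na.rm = TRUE)
```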
We can look at the matrix of beta values, where rows are CpGs and columns are samples.
```{r}
#beta values for the first CpGs across the first six samples
head(exprs(geo[, 1:6]))
```
We can also extract the corresponding sample metadata. The metadata includes the age, which we are trying to predict.
```{r}
skim(pData(geo))
```
Let's keep the skin samples that have a known age and are flagged as usable for ageing studies.
```{r}
geo$`age (years):ch1` <- as.numeric(geo$`age (years):ch1`)
geo <- geo[, geo$`canbeusedforagingstudies:ch1` == "yes" &
             geo$`tissue:ch1` == "Skin" &
             !is.na(geo$`age (years):ch1`)]
```
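We can check how many samples of each species remain after filtering:
```{r}
#number of remaining samples per bat species
table(geo$organism_ch1)
```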
To make this faster to run, and to show that we can do ML on smaller datasets, let's train on just one bat species: the greater spear-nosed bat (*Phyllostomus hastatus*). To test how generalisable the model is, we will then use it across species to predict the age of the big brown bat (*Eptesicus fuscus*).
```{r}
#helper function to extract a data matrix for a particular bat species
processData <- function(species, geo){
  geo_filtered <- geo[, geo$organism_ch1 == species]
  methyl_filtered <- as.data.frame(t(exprs(geo_filtered)))
  #transform age with sqrt(age + 1); predictions will be on this transformed scale
  methyl_filtered$age <- sqrt(as.numeric(geo_filtered$`age (years):ch1`) + 1)
  return(methyl_filtered)
}
#Let's keep only 1k random CpGs to help training speed for this workshop
set.seed(42)
keep <- sample.int(nrow(geo),1000)
geo <- geo[keep,]
#get the data for model building and testing
methyl_spearbat <- processData("Phyllostomus hastatus",geo)
methyl_bigbrownbat <- processData("Eptesicus fuscus",geo)
```
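Note that `processData()` puts age on a sqrt(age + 1) scale, so the model's predictions will be on that scale too. A minimal sketch of the transform and its inverse:
```{r}
#the age transform used in processData() and its inverse
age_years <- c(0, 1, 5, 10)
age_trans <- sqrt(age_years + 1)
age_trans^2 - 1 #back-transform recovers the ages in years
```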
## Create the training split
We keep 20% of the data for a final test; the remaining 80% will be used to train and tune the model.
```{r}
#Split the data into train and test
methyl_spearbat_split <- initial_split(methyl_spearbat, prop = 0.8, strata = age)
methyl_spearbat_train <- training(methyl_spearbat_split)
methyl_spearbat_test <- testing(methyl_spearbat_split)
```
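A quick check that the split behaved as expected:
```{r}
#roughly 80/20 split of the samples
nrow(methyl_spearbat_train)
nrow(methyl_spearbat_test)
```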
## Create the recipe
As before, we define the outcome and centre and scale the remaining predictors.
```{r}
#define the recipe
methyl_recipe <-
  recipe(methyl_spearbat_train) %>%
  update_role(everything(), new_role = "predictor") %>%
  update_role(age, new_role = "outcome") %>%
  step_center(all_predictors()) %>%
  step_scale(all_predictors())
```
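The recipe is just a specification at this point; to preview the preprocessed training data it will produce, we can prep and bake it (the centring/scaling statistics are estimated from the training data only):
```{r}
#estimate the recipe steps and preview the processed training data
methyl_recipe %>%
  prep() %>%
  bake(new_data = NULL) %>%
  select(1:5) %>%
  head()
```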
## Select the model
Let's use a glmnet model, which allows us to penalise the inclusion of variables to prevent overfitting and keep the model sparse. This is useful if we want to identify the minimal panel of biological features that is sufficient for a good prediction, e.g. for a biomarker panel.
mixture = 1 is known as a lasso model. Here we need to tune the penalty (lambda), which controls the down-weighting of variables (regularisation).
`tune` marks the penalty parameter as needing optimisation.
```{r}
#use glmnet model
glmn_fit <-
  linear_reg(mixture = 1, penalty = tune()) %>%
  set_engine("glmnet")
```
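To see how parsnip maps this specification onto the underlying glmnet call, we can print the translated spec:
```{r}
#show the underlying glmnet call that parsnip will generate
glmn_fit %>% translate()
```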
Let's cross-validate within the training dataset to allow us to tune the penalty.
```{r}
#5-fold cross-validation, stratified on age
folds <- vfold_cv(methyl_spearbat_train, v = 5, strata = age, breaks = 2)
```
## Create the workflow
We build the workflow from the model and the recipe.
```{r}
#build the workflow
methyl_wf <- workflow() %>%
  add_model(glmn_fit) %>%
  add_recipe(methyl_recipe)
```
## Define the tuning search space
Here we'll check the performance of the model as we vary the penalty.
```{r}
#define a sensible search
lasso_grid <- tibble(penalty = 10^seq(-3, 0, length.out = 50))
```
## Run the tuning workflow
We can use multiple CPU cores to speed up the tuning.
```{r}
#| message: false
#Using 6 cores
library(doParallel)
cl <- makeCluster(6)
registerDoParallel(cl)
#tune the model
methyl_res <- methyl_wf %>%
  tune_grid(resamples = folds,
            grid = lasso_grid,
            control = control_grid(save_pred = TRUE),
            metrics = metric_set(rmse))
#release the parallel workers now tuning is done
stopCluster(cl)
methyl_res
```
## How does the regularisation affect the performance?
We can find the best penalty value that minimises the error in our age prediction (rmse).
```{r}
autoplot(methyl_res)
```
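The same information is available as a table, e.g. the top candidate penalty values ranked by cross-validated rmse:
```{r}
#top 5 penalty values by cross-validated rmse
methyl_res %>% show_best(metric = "rmse", n = 5)
```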
## Finalise the model
Get the best model parameters
```{r}
best_mod <- methyl_res %>% select_best(metric = "rmse")
best_mod
```
Get the final model
```{r}
#fit on the training data using the best parameters
final_fitted <- finalize_workflow(methyl_wf, best_mod) %>%
  fit(data = methyl_spearbat_train)
```
## Test the performance
Look at the performance in the test dataset. How well does the clock work on a different species?
```{r}
#get the test performance
methyl_spearbat_aug <- augment(final_fitted, methyl_spearbat_test)
rmse(methyl_spearbat_aug, truth = age, estimate = .pred)
plot(methyl_spearbat_aug$age, methyl_spearbat_aug$.pred)
#try on the different species
methyl_bigbrownbat <- augment(final_fitted, methyl_bigbrownbat)
plot(methyl_bigbrownbat$age, methyl_bigbrownbat$.pred)
rmse(methyl_bigbrownbat, truth = age, estimate = .pred)
```
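Because the model predicts on the sqrt(age + 1) scale, it can be more interpretable to back-transform before reporting the error in years, e.g.:
```{r}
#report the test error on the original years scale
methyl_spearbat_aug %>%
  mutate(age_years = age^2 - 1,
         pred_years = .pred^2 - 1) %>%
  rmse(truth = age_years, estimate = pred_years)
```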
## What CpGs are important?
We can use the coefficients from the model to determine what CpGs are driving the prediction. We can also see how many variables have been retained in the model using our tuned penalty value.
```{r}
#| message: false
library(vip)
library(cowplot)
#get the importance from glmnet using the selected penalty
importance <- final_fitted %>%
  extract_fit_parsnip() %>%
  vi(lambda = best_mod$penalty) %>%
  mutate(
    Importance = abs(Importance),
    Variable = fct_reorder(Variable, Importance)
  )
#how many CpGs are retained
table(importance$Importance>0)
```
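The Sign column used in the next plot reflects the direction of the underlying coefficients. We can also inspect the raw, signed glmnet coefficients directly; since the workflow was finalised with the tuned penalty, `tidy()` should return the coefficients at that value:
```{r}
#raw signed coefficients at the tuned penalty (largest first)
final_fitted %>%
  extract_fit_parsnip() %>%
  tidy() %>%
  filter(estimate != 0) %>%
  arrange(desc(abs(estimate))) %>%
  head(10)
```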
## Plot the importance of the top CpGs and their direction
```{r}
#plot the top 10 CpGs
importance %>%
  slice_max(Importance, n = 10) %>%
  ggplot(aes(x = Importance, y = Variable, fill = Sign)) +
  geom_col() +
  scale_x_continuous(expand = c(0, 0)) +
  labs(y = NULL) +
  theme_cowplot()
```
## Plot the top predictive CpG beta values versus age
This highlights how you can use machine learning to identify a small number of discriminative features.
```{r}
#helper function to plot a CpG's beta values against (transformed) age
plotCpG <- function(cpg, dat){
  ggplot(dat, aes(x = !!sym(cpg), y = age)) +
    geom_point() +
    theme_cowplot()
}
#plot the most important CpGs
importance %>%
  slice_max(Importance, n = 4) %>%
  pull(Variable) %>%
  as.character() %>%
  map(plotCpG, methyl_spearbat) %>%
  plot_grid(plotlist = .)
```