-
Notifications
You must be signed in to change notification settings - Fork 7
/
xgboost.r
19 lines (15 loc) · 1.11 KB
/
xgboost.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
library(tidyverse)
library(xgboost)
# Randomly label every row of `diamonds` as 1 or 2 with probabilities
# 0.7 / 0.3, then split on that label. Each row is drawn independently, so
# the split is approximately (not exactly) 70/30.
# The seed is fixed so that the same rows land in each set on every run.
set.seed(42)
ind <- sample(2, nrow(diamonds), replace = TRUE, prob = c(0.7, 0.3))
train.set <- diamonds[ind == 1, ]  # rows labelled 1
test.set <- diamonds[ind == 2, ]   # rows labelled 2
# Build the all-numeric feature table for one data set.
#
# Keeps the numeric columns of `df` and appends the model-matrix expansion of
# the `cut`, `color` and `clarity` columns. The `- 1` in each formula removes
# the intercept column from that expansion.
#
# df: a data frame with numeric columns plus `cut`, `color` and `clarity`.
# Returns a tibble: numeric columns first, then the expanded columns.
encode_features <- function(df) {
  # Expand a single one-sided formula against `df` and return it as a tibble.
  expand_term <- function(term) {
    as_tibble(model.matrix(term, df))
  }
  bind_cols(
    select_if(df, is.numeric),
    expand_term(~ cut - 1),
    expand_term(~ color - 1),
    expand_term(~ clarity - 1)
  )
}

# Training and test tables are produced by the same helper so both end up
# with the same columns in the same order.
# NOTE: the data frame `xgb.train` has the same name as xgboost's
# xgb.train() function. Calls such as xgb.train(...) still reach the function
# because R skips non-function objects when resolving a call; the variable
# name is kept so existing references to it keep working.
xgb.train <- encode_features(train.set)
xgb.test <- encode_features(test.set)

# xgboost wants a numeric feature matrix and a separate label vector, so
# `price` is removed from the features and passed as the label instead.
xgboost.train <- xgb.DMatrix(
  data = as.matrix(select(xgb.train, -price)),
  label = xgb.train$price
)
xgboost.test <- xgb.DMatrix(
  data = as.matrix(select(xgb.test, -price)),
  label = xgb.test$price
)
# Booster parameters handed to xgboost: RMSE as the evaluation metric,
# gamma of 1, maximum tree depth of 6, and 3 threads.
param <- list(
  eval_metric = "rmse",
  gamma = 1,
  max_depth = 6,
  nthread = 3
)

# Fit the booster on the training DMatrix for at most 500 rounds. The metric
# is evaluated on the watchlist entry named "test" and printed every 30
# rounds; early_stopping_rounds = 60 is passed so training can stop before
# round 500 (see the xgboost documentation for the exact stopping rule).
# The call is namespace-qualified to make it explicit that the xgboost
# function is meant.
# NOTE(review): the watchlist is the same DMatrix that is later used to
# report prediction error, so that error may be optimistic -- consider a
# separate validation split.
xg.model <- xgboost::xgb.train(
  params = param,
  data = xgboost.train,
  nrounds = 500,
  watchlist = list(test = xgboost.test),
  early_stopping_rounds = 60,
  print_every_n = 30
)
# Predict prices for the test DMatrix with the fitted booster.
xg.predict <- predict(xg.model, xgboost.test)

# Residuals (actual minus predicted price), computed once and reused below.
residuals.xgb <- test.set$price - xg.predict

# Root-mean-squared error of the test-set predictions.
rmse.xgb <- sqrt(mean(residuals.xgb^2))
# Backward-compatible alias: despite its name, `mse.xgb` has always held the
# RMSE (the square root is applied), not the mean squared error.
mse.xgb <- rmse.xgb

# Residuals in test-set row order.
plot(
  residuals.xgb,
  ylab = "Residual (actual - predicted price)",
  main = "xgboost test-set residuals"
)