# bpr_postclick.py
import os
import sys
import argparse

import numpy as np

from openrec import ModelTrainer
from openrec.recommenders import BPR
from openrec.utils import Dataset
from openrec.utils.evaluators import AUC
from openrec.utils.samplers import RandomPairwiseSampler, EvaluationSampler

from stratified_pairwise_sampler import StratifiedPairwiseSampler
from dataloader import loadSpotify, loadByteDance
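# Requires OpenRec (https://github.com/ylongqi/openrec) plus the repo-local
# stratified_pairwise_sampler.py and dataloader.py modules.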
### training parameters ###
total_iter = 10000     # total number of training iterations
batch_size = 1000      # training batch size
eval_iter = 1000       # evaluate every eval_iter iterations
save_iter = eval_iter  # save the model every save_iter iterations

### embedding dimensions ###
dim_user_embed = 100   # dimension of the user embedding
dim_item_embed = 100   # dimension of the item embedding
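# With these settings the sampler draws total_iter * batch_size = 10M training
# pairs overall, and the model is evaluated and checkpointed every 1,000 iterations.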
def exp(dataset, l2_reg, p_n_ratio, eval_explicit, save_log, eval_rank):
    # load the raw interaction data
    if dataset == 'spotify':
        data = loadSpotify()
    elif dataset == 'bytedance':
        data = loadByteDance()
    else:
        print("Unsupported dataset...")
        return

    # set up the logging / model-checkpoint directory
    log_dir = "validation_logs/{}_{}_{}_{}_{}/".format(dataset, l2_reg, p_n_ratio, eval_explicit, eval_rank)
    os.makedirs(log_dir, exist_ok=True)
    if save_log:
        log = open(log_dir + "validation.log", "w")
        sys.stdout = log
    # prepare the training set and its pairwise sampler
    train_dataset = Dataset(data['train'], data['total_users'], data['total_items'], name='Train')
    if p_n_ratio is None:
        train_sampler = RandomPairwiseSampler(batch_size=batch_size, dataset=train_dataset, num_process=5)
    else:
        train_sampler = StratifiedPairwiseSampler(batch_size=batch_size, dataset=train_dataset,
                                                  p_n_ratio=p_n_ratio, num_process=5)
        if p_n_ratio > 0.0:
            print("Re-weighting implicit negative feedback")
        else:
            print("Corrected negative feedback labels but not re-weighting")
    eval_num_neg = None if eval_explicit else 500  # number of negative samples for evaluation
    if eval_rank:
        # show evaluation metrics for click-complete and click-skip items separately
        pos_dataset = Dataset(data['pos_test'], data['total_users'], data['total_items'],
                              implicit_negative=not eval_explicit, name='Pos_Test', num_negatives=eval_num_neg)
        neg_dataset = Dataset(data['neg_test'], data['total_users'], data['total_items'],
                              implicit_negative=not eval_explicit, name='Neg_Test', num_negatives=eval_num_neg)
        pos_sampler = EvaluationSampler(batch_size=batch_size, dataset=pos_dataset)
        neg_sampler = EvaluationSampler(batch_size=batch_size, dataset=neg_dataset)
        eval_samplers = [pos_sampler, neg_sampler]
    else:
        val_dataset = Dataset(data['val'], data['total_users'], data['total_items'],
                              implicit_negative=not eval_explicit, name='Val', num_negatives=eval_num_neg)
        test_dataset = Dataset(data['test'], data['total_users'], data['total_items'],
                               implicit_negative=not eval_explicit, name='Test', num_negatives=eval_num_neg)
        val_sampler = EvaluationSampler(batch_size=batch_size, dataset=val_dataset)
        test_sampler = EvaluationSampler(batch_size=batch_size, dataset=test_dataset)
        eval_samplers = [val_sampler, test_sampler]
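    # Note (assumption about OpenRec internals): with implicit_negative=True the
    # Dataset samples eval_num_neg unobserved items per user, so the reported AUC
    # is an estimate against 500 sampled negatives rather than all unobserved items.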
    # set evaluators
    auc_evaluator = AUC()
    evaluators = [auc_evaluator]

    # build the BPR model
    model = BPR(l2_reg=l2_reg,
                batch_size=batch_size,
                total_users=train_dataset.total_users(),
                total_items=train_dataset.total_items(),
                dim_user_embed=dim_user_embed,
                dim_item_embed=dim_item_embed,
                save_model_dir=log_dir,
                train=True,
                serve=True)
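    # BPR learns the embeddings by maximizing ln(sigmoid(score(u, i) - score(u, j)))
    # over (user, positive item, negative item) triples, with l2_reg penalizing the
    # embedding norms (Rendle et al., 2009).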
    # train and periodically evaluate / checkpoint the model
    model_trainer = ModelTrainer(model=model)
    model_trainer.train(total_iter=total_iter,
                        eval_iter=eval_iter,
                        save_iter=save_iter,
                        train_sampler=train_sampler,
                        eval_samplers=eval_samplers,
                        evaluators=evaluators)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='BPR post-click feedback experiment')
    parser.add_argument('--dataset', type=str, default='bytedance', help='dataset to use')
    parser.add_argument('--l2_reg', type=float, default=0.01, help='L2 regularization weight on the latent factors')
    parser.add_argument('--p_n_ratio', type=float, default=None, help='positive-to-negative pair ratio for the stratified sampler (default: plain BPR sampling)')
    parser.add_argument('--eval_explicit', action='store_true', help='evaluate with explicit post-click labels; by default clicks are treated as positives and non-clicks as negatives')
    parser.add_argument('--eval_rank', action='store_true', help='report ranking accuracy separately for positive and negative test samples')
    parser.add_argument('--log', action='store_true', help='log results to a file instead of printing to the screen')
    args = parser.parse_args()
    print(args)

    # run the experiment
    exp(dataset=args.dataset, l2_reg=args.l2_reg, p_n_ratio=args.p_n_ratio,
        eval_explicit=args.eval_explicit, save_log=args.log, eval_rank=args.eval_rank)
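# Example invocations (data preparation per dataloader.py is assumed):
#   python bpr_postclick.py --dataset bytedance --p_n_ratio 0.5 --eval_rank
#   python bpr_postclick.py --dataset spotify --l2_reg 0.001 --eval_explicit --log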