# run.py — 67 lines (62 loc), 2.24 KB.
"""Train and evaluate SASRec on the ML-1M dataset with RecBole.

Run as a script (``python run.py``). ``parameter_dict`` is passed to
RecBole as ``config_dict``; its entries override RecBole's defaults.
"""

# Hyperparameter overrides handed to ``run_recbole(config_dict=...)``.
parameter_dict = {
    # 'neg_sampling': {'popularity': 100},
    'learning_rate': 0.001,  # 0.001
    # 'weight_decay': 0.01,
    'train_batch_size': 2048,  # 512 for gowalla
    'eval_batch_size': 2048,
    # No negative sampling during training. NOTE(review): RecBole requires
    # this to be None when loss_type is 'CE' (SASRec's default) — confirm.
    'train_neg_sample_args': None,
    # Older key for the same setting; NOTE(review): presumably only read by
    # RecBole releases before 1.1 — likely redundant, confirm before removing.
    'neg_sampling': None,
    # Used by BERT4Rec (see the commented call in main()); presumably
    # ignored by SASRec.
    'mask_ratio': 0.2,
    'hidden_size': 128,
    'inner_size': 256,
    'n_layers': 2,
    'n_heads': 8,
    'hidden_dropout_prob': 0.2,
    'attn_dropout_prob': 0.2,
    'hidden_act': 'gelu',
    'layer_norm_eps': 1e-12,
    'initializer_range': 0.02,
    # 'loss_type': 'CE',
    # 'eval_args': {'split': {'LS': 'valid_and_test'}, 'order': 'TO', 'mode': 'pop100', 'group_by': 'user'},
    'topk': 10,
    'metrics': ['Recall', 'MRR', 'NDCG'],
    # Model selection metric; must match one of `metrics` at a `topk` cutoff.
    'valid_metric': 'NDCG@10',
}


def main():
    """Run RecBole's quick-start pipeline for SASRec on ML-1M.

    Returns:
        Whatever ``run_recbole`` returns (the original script discarded it).
    """
    # Imported here rather than at module level so the configuration above
    # can be imported without RecBole installed. The original script had this
    # import commented out, so the call below raised NameError.
    from recbole.quick_start import run_recbole

    # print(run_recbole(model='BERT4Rec', dataset='gowalla-merged', config_file_list=['gowalla.yaml'], config_dict=parameter_dict))
    # NOTE(review): RecBole's built-in (auto-downloadable) dataset is named
    # 'ml-1m'; 'ML-1M' presumably refers to a local dataset folder — confirm.
    return run_recbole(model='SASRec', dataset='ML-1M', config_dict=parameter_dict)


if __name__ == "__main__":
    main()
# FDSA, SASRecF
# from recbole.quick_start import run_recbole
# parameter_dict = {
# # 'neg_sampling': {'popularity': 100},
# 'learning_rate': 0.001, #0.001
# #'weight_decay': 0.01,
# 'train_batch_size': 1024, #2048
# 'eval_batch_size': 1024, #2048
# 'neg_sampling': None,
# 'mask_ratio': 0.2,
# 'hidden_size': 64, #128
# 'inner_size': 256,
# 'n_layers': 2,
# 'n_heads': 8,
# 'hidden_dropout_prob': 0.2,
# 'attn_dropout_prob': 0.2,
# 'hidden_act': 'gelu',
# 'layer_norm_eps': 1e-12,
# 'initializer_range': 0.02,
# # 'loss_type': 'CE',
# # 'eval_args': {'split': {'LS': 'valid_and_test'}, 'order': 'TO', 'mode': 'pop100', 'group_by': 'user'},
# 'topk': 10,
# 'metrics': ['Recall', 'MRR', 'NDCG'],
# 'valid_metric': 'NDCG@10',
# 'train_neg_sample_args': None,
# # 'load_col':
# # {'inter': ['user_id', 'item_id', 'rating', 'timestamp'],
# # 'item': ['item_id', 'genre']},
# # 'selected_features': ['genre'],
# # # gowalla
# # 'load_col':
# # {'inter': ['user_id', 'item_id', 'rating', 'timestamp']},
# # 'selected_features': ['item_id'],
# }
# run_recbole(model='SASRec', dataset='gowalla', config_dict=parameter_dict)