-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_logistic_label_wise_pairwise_SVRG_BB_new.m
138 lines (117 loc) · 4.43 KB
/
train_logistic_label_wise_pairwise_SVRG_BB_new.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
function [ W ] = train_logistic_label_wise_pairwise_SVRG_BB_new( X, Y, lambda, alpha )
% Trains one linear model per label with the pairwise surrogate loss
% (base loss: logistic), optimized label-by-label via SVRG-BB.
%   X:      [num_instance x num_feature] feature matrix
%   Y:      [num_instance x num_label] label matrix (>0 positive, <0 negative)
%   lambda: weight of the L2-norm regularizer
%   alpha:  initial learning rate passed to the per-label solver
%   W:      [num_feature x num_label] learned weight matrix, one column per label
num_feature = size(X, 2);
num_label = size(Y, 2);
W = zeros(num_feature, num_label);
for label_idx = 1: num_label
fprintf('Training the label %d:\n', label_idx);
W(:, label_idx) = train_model_for_one_label(X, Y(:, label_idx), lambda, alpha);
end
end
function [w] = train_model_for_one_label(X, y, lambda, alpha)
% Trains a linear scorer w for a single label by minimizing the
% L2-regularized pairwise logistic loss over (positive, negative)
% instance pairs, using SVRG with Barzilai-Borwein (BB) step sizes.
%   X:      [num_instance x num_feature] feature matrix
%   y:      per-instance label column; entries > 0 mark positives, < 0 negatives
%   lambda: L2 regularization weight
%   alpha:  initial step size (replaced by the BB step from epoch 2 onward)
%   w:      [num_feature x 1] learned weight vector
[num_instance, num_feature] = size(X);
w = zeros(num_feature, 1);
p_list = find(y > 0);
q_list = find(y < 0);
num_pos = length(p_list);
num_neg = length(q_list);
% No positives or no negatives means there are no pairs to rank;
% return the zero vector unchanged.
if num_pos == 0 || num_neg == 0
return;
end
% Do serveral SGD steps first -- a plain-SGD warm start on random pairs
% before the SVRG epochs begin.
for i = 1: 10
pos_index_tmp = randi(num_pos);
pos_index = p_list(pos_index_tmp);
neg_index_tmp = randi(num_neg);
neg_index = q_list(neg_index_tmp);
GD_one = calculate_one_gradient(X, pos_index, neg_index, w, lambda);
w = w - alpha * GD_one;
% size(w)
end
num_s = 30;  % maximum number of SVRG epochs
%num_s = 0;
obj = zeros(num_s, 1);  % objective value recorded after each epoch
m = 2 * (num_pos * num_neg);  % inner-loop length per epoch
epsilon = 10^-6;  % relative-improvement tolerance for early stopping
for i = 1: num_s
% Snapshot the current iterate and its full gradient for this epoch.
w1 = w;
fG1 = calculate_all_gradient(X, p_list, q_list, w1, lambda);
if i > 1
% Stop once the relative change of the objective between the two
% most recent epochs drops below epsilon.
if i > 2 && abs(obj(i-1, 1) - obj(i-2, 1)) / obj(i-2, 1) <= epsilon
break;
end
% Barzilai-Borwein step size from consecutive snapshots, scaled by 1/m.
% NOTE(review): no guard against a zero or negative denominator
% trace((w1-w0)'*(fG1-fG0)) -- confirm this cannot blow up alpha.
alpha = norm(w1-w0, 'fro')^2 / trace((w1-w0)'*(fG1-fG0)) / m;
end
fG0 = fG1;
w0 = w1;
% SVRG inner loop: variance-reduced stochastic steps on random pairs,
% update direction g(w) - g(w1) + fG1.
for j = 1: m
pos_index_tmp = randi(num_pos);
pos_index = p_list(pos_index_tmp);
neg_index_tmp = randi(num_neg);
neg_index = q_list(neg_index_tmp);
GD_one = calculate_one_gradient(X, pos_index, neg_index, w, lambda);
GD_ = calculate_one_gradient(X, pos_index, neg_index, w1, lambda);
w = w - alpha * (GD_one - GD_ + fG1);
%if isnan(W)
% return;
%end
end
obj(i,1) = calculate_objective_function(X, p_list, q_list, w, lambda);
fprintf('Step %d: the objective function value is %.5f\n', i, obj(i,1));
end
end
function [f_value] = calculate_objective_function(X, p_list, q_list, w, lambda)
% Regularized pairwise logistic objective:
%   0.5*lambda*||w||^2 + mean over all (p, q) pairs of
%   log(1 + exp(-w'(x_p - x_q)))
%   X:      [num_instance x num_feature] feature matrix
%   p_list: row indices of positive instances
%   q_list: row indices of negative instances
%   w:      [num_feature x 1] weight vector
%   lambda: L2 regularization weight
f_value = 0.5 * lambda * norm(w, 'fro')^2;
f_value_loss = 0;
pos_num = length(p_list);
neg_num = length(q_list);
% Fix: the original used `for p = pos_num` / `for q = neg_num`, which in
% MATLAB loops over the single scalar value, so only the LAST
% (positive, negative) pair was evaluated instead of all pairs.
for p = 1: pos_num
for q = 1: neg_num
% hinge loss
% f_value_loss = f_value_loss + max(0, 1 - dot(w, X(p_list(p),:) - X(q_list(q),:)));
% logistic loss
f_value_loss = f_value_loss + log(1 + exp(- dot(w, X(p_list(p),:) - X(q_list(q),:))));
end
end
f_value = f_value + f_value_loss / (pos_num * neg_num);
end
function [grad] = calculate_all_gradient(X, p_list, q_list, w, lambda)
% Full gradient of the regularized pairwise logistic loss over all
% (positive, negative) pairs:
%   lambda*w + mean over pairs of (x_q - x_p)' * sigmoid(-w'(x_p - x_q))
% Fix: the original accumulated BOTH the logistic gradient and a hinge
% gradient for every pair, but the objective and the stochastic gradient
% (calculate_one_gradient) use the logistic loss only -- the extra hinge
% term double-counted and made the SVRG full gradient inconsistent with
% the per-pair gradients. The hinge variant is kept commented out,
% matching the style of calculate_one_gradient.
%   X:      [num_instance x num_feature] feature matrix
%   p_list: row indices of positive instances
%   q_list: row indices of negative instances
%   w:      [num_feature x 1] weight vector
%   lambda: L2 regularization weight
num_feature = size(X, 2);
grad_loss = zeros(num_feature, 1);
pos_num = length(p_list);
neg_num = length(q_list);
for p = 1: pos_num
for q = 1: neg_num
% logistic loss gradient for this pair
grad_loss = grad_loss + (X(q_list(q),:) - X(p_list(p),:))' / (1 + exp(dot(w, X(p_list(p),:) - X(q_list(q),:))));
% hinge loss (disabled; keep consistent with the logistic objective)
% if dot(w, X(p_list(p),:) - X(q_list(q),:)) <= 1
%     grad_loss = grad_loss - (X(p_list(p),:) - X(q_list(q),:))';
% end
end
end
grad = lambda * w + grad_loss / (pos_num * neg_num);
end
function [grad_one] = calculate_one_gradient(X, pos_index, neg_index, w, lambda)
% Stochastic gradient of the regularized pairwise logistic loss for one
% (positive, negative) instance pair:
%   lambda*w + sigmoid(-w'(x_p - x_q)) * (x_q - x_p)'
%   X:         [num_instance x num_feature] feature matrix
%   pos_index: row index of the positive instance
%   neg_index: row index of the negative instance
%   w:         [num_feature x 1] weight vector
%   lambda:    L2 regularization weight
pair_diff = X(pos_index,:) - X(neg_index,:);
% Logistic (rank) loss term: -(x_p - x_q)' / (1 + exp(w'(x_p - x_q))),
% i.e. the negated pair difference weighted by sigmoid(-margin).
% Hinge variant (disabled):
%   grad_rank = -(x_p - x_q)' when w'(x_p - x_q) <= 1, else 0.
grad_one = lambda * w - pair_diff' ./ (1 + exp(dot(w, pair_diff)));
end