Commit 98697b6 (1 parent: c42c6b7)
Showing 8 changed files with 700 additions and 0 deletions.
@@ -22,3 +22,9 @@ TransX/Tf/result/
Graph/data/cora/README
gat/__pycache__/
EGES/__pycache__/
EGES/data_cache/
EGES/data/
@@ -0,0 +1,5 @@
# -*- coding:utf-8 -*-
# @Time : 2021/9/12 12:26 AM
# @Author : huichuan LI
# @File : __init__.py.py
# @Software: PyCharm
@@ -0,0 +1,55 @@
import numpy as np


def create_alias_table(area_ratio):
    """
    :param area_ratio: sum(area_ratio)=1
    :return: accept, alias
    """
    l = len(area_ratio)
    area_ratio = [prop * l for prop in area_ratio]
    accept, alias = [0] * l, [0] * l
    small, large = [], []

    for i, prob in enumerate(area_ratio):
        if prob < 1.0:
            small.append(i)
        else:
            large.append(i)

    while small and large:
        small_idx, large_idx = small.pop(), large.pop()
        accept[small_idx] = area_ratio[small_idx]
        alias[small_idx] = large_idx
        area_ratio[large_idx] = area_ratio[large_idx] - \
            (1 - area_ratio[small_idx])
        if area_ratio[large_idx] < 1.0:
            small.append(large_idx)
        else:
            large.append(large_idx)

    while large:
        large_idx = large.pop()
        accept[large_idx] = 1
    while small:
        small_idx = small.pop()
        accept[small_idx] = 1

    return accept, alias


def alias_sample(accept, alias):
    """
    :param accept:
    :param alias:
    :return: sample index
    """
    N = len(accept)
    i = int(np.random.random() * N)
    r = np.random.random()
    if r < accept[i]:
        return i
    else:
        return alias[i]
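
A minimal usage sketch (not part of the commit) of the alias-method helpers above; the module name `alias` is an assumption, since the file name is not shown in this listing:

import numpy as np
from alias import create_alias_table, alias_sample  # assumed module name

# Toy distribution over 4 items; alias_sample then draws one index in O(1).
probs = [0.1, 0.2, 0.3, 0.4]
accept, alias_idx = create_alias_table(probs)
samples = [alias_sample(accept, alias_idx) for _ in range(100000)]
print(np.bincount(samples) / len(samples))  # empirically close to [0.1, 0.2, 0.3, 0.4]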
@@ -0,0 +1,126 @@
import pandas as pd
import numpy as np
from itertools import chain
import pickle
import time
import networkx as nx

from sklearn.preprocessing import LabelEncoder
import argparse
from walker import RandomWalker


def cnt_session(data, time_cut=30, cut_type=2):
    sku_list = data['sku_id']
    time_list = data['action_time']
    type_list = data['type']
    session = []
    tmp_session = []
    for i, item in enumerate(sku_list):
        # close the current session on a "cut" action type, on a gap longer than
        # time_cut minutes, or at the user's last action
        if type_list[i] == cut_type or (
                i < len(sku_list) - 1
                and (time_list[i + 1] - time_list[i]).total_seconds() / 60 > time_cut) or i == len(sku_list) - 1:
            tmp_session.append(item)
            session.append(tmp_session)
            tmp_session = []
        else:
            tmp_session.append(item)
    return session


def get_session(action_data, use_type=None):
    if use_type is None:
        use_type = [1, 2, 3, 5]
    action_data = action_data[action_data['type'].isin(use_type)]
    action_data = action_data.sort_values(by=['user_id', 'action_time'], ascending=True)
    group_action_data = action_data.groupby('user_id').agg(list)
    session_list = group_action_data.apply(cnt_session, axis=1)
    return session_list.to_numpy()


def get_graph_context_all_pairs(walks, window_size):
    all_pairs = []
    for k in range(len(walks)):
        for i in range(len(walks[k])):
            for j in range(i - window_size, i + window_size + 1):
                if i == j or j < 0 or j >= len(walks[k]):
                    continue
                else:
                    all_pairs.append([walks[k][i], walks[k][j]])
    return np.array(all_pairs, dtype=np.int32)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument("--data_path", type=str, default='./data/')
    parser.add_argument("--p", type=float, default=0.25)
    parser.add_argument("--q", type=float, default=2)
    parser.add_argument("--num_walks", type=int, default=10)
    parser.add_argument("--walk_length", type=int, default=10)
    parser.add_argument("--window_size", type=int, default=5)
    args = parser.parse_known_args()[0]

    action_data = pd.read_csv(args.data_path + 'action_head.csv', parse_dates=['action_time']).drop('module_id',
                                                                                                     axis=1).dropna()
    all_skus = action_data['sku_id'].unique()
    all_skus = pd.DataFrame({'sku_id': list(all_skus)})
    sku_lbe = LabelEncoder()
    all_skus['sku_id'] = sku_lbe.fit_transform(all_skus['sku_id'])
    action_data['sku_id'] = sku_lbe.transform(action_data['sku_id'])

    print('make session list\n')
    start_time = time.time()
    session_list = get_session(action_data, use_type=[1, 2, 3, 5])
    session_list_all = []
    for item_list in session_list:
        for session in item_list:
            if len(session) > 1:
                session_list_all.append(session)

    print('make session list done, time cost {0}'.format(str(time.time() - start_time)))

    # session2graph: count consecutive-item co-occurrences as directed edge weights
    node_pair = dict()
    for session in session_list_all:
        for i in range(1, len(session)):
            if (session[i - 1], session[i]) not in node_pair.keys():
                node_pair[(session[i - 1], session[i])] = 1
            else:
                node_pair[(session[i - 1], session[i])] += 1

    in_node_list = list(map(lambda x: x[0], list(node_pair.keys())))
    out_node_list = list(map(lambda x: x[1], list(node_pair.keys())))
    weight_list = list(node_pair.values())
    graph_df = pd.DataFrame({'in_node': in_node_list, 'out_node': out_node_list, 'weight': weight_list})
    graph_df.to_csv('./data_cache/graph.csv', sep=' ', index=False, header=False)

    G = nx.read_edgelist('./data_cache/graph.csv', create_using=nx.DiGraph(), nodetype=None, data=[('weight', int)])
    walker = RandomWalker(G, p=args.p, q=args.q)
    print("Preprocess transition probs...")
    walker.preprocess_transition_probs()

    session_reproduce = walker.simulate_walks(num_walks=args.num_walks, walk_length=args.walk_length, workers=4,
                                              verbose=1)
    session_reproduce = list(filter(lambda x: len(x) > 2, session_reproduce))

    # add side info
    product_data = pd.read_csv(args.data_path + 'jdata_product.csv').drop('market_time', axis=1).dropna()

    all_skus['sku_id'] = sku_lbe.inverse_transform(all_skus['sku_id'])
    print("sku nums: " + str(all_skus.count()))
    sku_side_info = pd.merge(all_skus, product_data, on='sku_id', how='left').fillna(0)

    # id2index: label-encode every side-information column; the sku_id column reuses sku_lbe
    for feat in sku_side_info.columns:
        if feat != 'sku_id':
            lbe = LabelEncoder()
            sku_side_info[feat] = lbe.fit_transform(sku_side_info[feat])
        else:
            sku_side_info[feat] = sku_lbe.transform(sku_side_info[feat])

    sku_side_info = sku_side_info.sort_values(by=['sku_id'], ascending=True)
    sku_side_info.to_csv('./data_cache/sku_side_info.csv', index=False, header=False, sep='\t')

    # get (center, context) pairs for skip-gram training
    all_pairs = get_graph_context_all_pairs(session_reproduce, args.window_size)
    np.savetxt('./data_cache/all_pairs', X=all_pairs, fmt="%d", delimiter=" ")
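
For illustration (not part of the commit), this is what get_graph_context_all_pairs produces for a toy walk, assuming the function defined above is in scope:

walks = [[5, 7, 9]]  # one toy random walk over node ids
pairs = get_graph_context_all_pairs(walks, window_size=1)
print(pairs)
# [[5 7]
#  [7 5]
#  [7 9]
#  [9 7]]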
@@ -0,0 +1,99 @@
# -*- coding:utf-8 -*-
# @Time : 2021/9/12 12:26 AM
# @Author : huichuan LI
# @File : eges.py
# @Software: PyCharm
import numpy as np
import tensorflow as tf
from tensorflow import keras


class EGES_Model(keras.Model):
    def __init__(self, num_nodes, num_feat, feature_lens, n_sampled=100, embedding_dim=128, lr=0.001, **kwargs):
        self.n_sampled = n_sampled
        self.num_feat = num_feat
        self.feature_lens = feature_lens
        self.embedding_dim = embedding_dim
        self.num_nodes = num_nodes
        self.lr = lr
        super(EGES_Model, self).__init__(**kwargs)

    def build(self, input_shapes):
        # noise-contrastive estimation (sampled softmax) parameters
        self.nce_w = self.add_weight(
            name="nce_w", shape=[self.num_nodes, self.embedding_dim],
            initializer=keras.initializers.TruncatedNormal(0., 0.1))  # [n_vocab, emb_dim]
        self.nce_b = self.add_weight(
            name="nce_b", shape=(self.num_nodes,),
            initializer=keras.initializers.Constant(0.1))  # [n_vocab, ]

        # one embedding table per side-information feature (feature 0 is the item id itself)
        cat_embedding_vars = []
        for i in range(self.num_feat):
            embedding_var = self.add_weight(
                shape=[self.feature_lens[i], self.embedding_dim],
                initializer=keras.initializers.TruncatedNormal(0., 0.1),
                name='embedding' + str(i),
                trainable=True)
            cat_embedding_vars.append(embedding_var)
        self.cat_embedding = cat_embedding_vars
        # per-node attention logits over the num_feat embeddings
        self.alpha_embedding = self.add_weight(
            name="alpha_embedding", shape=(self.num_nodes, self.num_feat),
            initializer=keras.initializers.Constant(0.1))

    def attention_merge(self, batch_features):
        embed_list = []
        for i in range(self.num_feat):
            cat_embed = tf.nn.embedding_lookup(self.cat_embedding[i], batch_features[:, i])
            embed_list.append(cat_embed)
        stack_embed = tf.stack(embed_list, axis=-1)  # [batch, emb_dim, num_feat]
        # attention merge: exp-weighted average of the per-feature embeddings,
        # with weights looked up by node id (feature column 0)
        alpha_embed = tf.nn.embedding_lookup(self.alpha_embedding, batch_features[:, 0])
        alpha_embed_expand = tf.expand_dims(alpha_embed, 1)  # [batch, 1, num_feat]
        alpha_i_sum = tf.reduce_sum(tf.exp(alpha_embed_expand), axis=-1)
        merge_emb = tf.reduce_sum(stack_embed * tf.exp(alpha_embed_expand), axis=-1) / alpha_i_sum
        return merge_emb

    def make_skipgram_loss(self, labels):
        loss = tf.reduce_mean(tf.nn.sampled_softmax_loss(
            weights=self.nce_w,
            biases=self.nce_b,
            labels=tf.expand_dims(labels, axis=1),
            inputs=self.merge_emb,
            num_sampled=self.n_sampled,
            num_classes=self.num_nodes))

        return loss

    def call(self, side_info, batch_index, batch_labels):
        self.side_info = tf.convert_to_tensor(side_info)
        self.batch_features = tf.nn.embedding_lookup(self.side_info, batch_index)
        self.merge_emb = self.attention_merge(self.batch_features)
        return self.make_skipgram_loss(batch_labels)

    def get_embedding(self, batch_index):
        batch_features = tf.nn.embedding_lookup(self.side_info, batch_index)
        return self.attention_merge(batch_features)
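
A small, hypothetical smoke test for the model above. The call pattern mirrors this commit's run script; the data, shapes, and values are made up:

import numpy as np
from eges import EGES_Model

# 10 nodes, 3 side-information columns (column 0 is the node id itself).
side_info = np.stack([np.arange(10),                 # node id
                      np.random.randint(0, 4, 10),   # e.g. category id
                      np.random.randint(0, 2, 10)],  # e.g. brand id
                     axis=1).astype(np.int32)
feature_lens = [10, 4, 2]

model = EGES_Model(num_nodes=10, num_feat=3, feature_lens=feature_lens,
                   n_sampled=2, embedding_dim=8)
batch_index = np.array([0, 1, 2], dtype=np.int32)   # center nodes
batch_labels = np.array([1, 2, 3], dtype=np.int32)  # context nodes
loss = model(side_info, batch_index, batch_labels)  # sampled-softmax skip-gram loss
print(float(loss))
print(model.get_embedding(batch_index).shape)       # (3, 8)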
@@ -0,0 +1,87 @@
# -*- coding:utf-8 -*-
# @Time : 2021/9/12 12:11 PM
# @Author : huichuan LI
# @File : run_eges.py
# @Software: PyCharm


import pandas as pd
import numpy as np
import tensorflow as tf
import time
import argparse

from eges import EGES_Model

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='manual to this script')
    parser.add_argument("--batch_size", type=int, default=2048)
    parser.add_argument("--n_sampled", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--root_path", type=str, default='./data_cache/')
    parser.add_argument("--num_feat", type=int, default=4)
    parser.add_argument("--embedding_dim", type=int, default=128)
    parser.add_argument("--outputEmbedFile", type=str, default='./embedding/EGES.embed')
    args = parser.parse_args()

    # read train data
    print('read features...')
    start_time = time.time()
    side_info = np.loadtxt(args.root_path + 'sku_side_info.csv', dtype=np.int32, delimiter='\t')
    feature_lens = []
    for i in range(side_info.shape[1]):
        tmp_len = len(set(side_info[:, i]))
        feature_lens.append(tmp_len)
    end_time = time.time()
    print('time consumed for read features: %.2f' % (end_time - start_time))


    # read data pairs with tf.data
    def decode_data_pair(line):
        columns = tf.strings.split([line], ' ')
        x = tf.strings.to_number(columns.values[0], out_type=tf.int32)
        y = tf.strings.to_number(columns.values[1], out_type=tf.int32)
        return x, y


    dataset = tf.data.TextLineDataset(args.root_path + 'all_pairs').map(
        decode_data_pair, num_parallel_calls=tf.data.AUTOTUNE).prefetch(500000)
    # dataset = dataset.shuffle(256)
    dataset = dataset.repeat(args.epochs)
    dataset = dataset.batch(args.batch_size)  # batch size to use

    print('build EGES model...')
    start_time = time.time()
    EGES = EGES_Model(len(side_info), args.num_feat, feature_lens,
                      n_sampled=args.n_sampled, embedding_dim=args.embedding_dim, lr=args.lr)
    end_time = time.time()
    print('time consumed for building model: %.2f' % (end_time - start_time))

    opt = tf.keras.optimizers.Adam(args.lr)
    print_every_k_iterations = 100
    iteration = 0
    total_loss = 0.0
    start = time.time()
    # iterate over the finite dataset (args.epochs passes over all_pairs)
    for batch_index, batch_labels in dataset:
        iteration += 1

        with tf.GradientTape() as tape:
            loss = EGES(side_info, batch_index, batch_labels)
        # compute gradients and update the parameters
        gradients = tape.gradient(loss, EGES.trainable_variables)
        opt.apply_gradients(zip(gradients, EGES.trainable_variables))
        total_loss += float(loss)

        if iteration % print_every_k_iterations == 0:
            end = time.time()
            print("Iteration: {}".format(iteration),
                  "Avg. Training loss: {:.4f}".format(total_loss / print_every_k_iterations),
                  "{:.4f} sec/batch".format((end - start) / print_every_k_iterations))
            total_loss = 0.0
            start = time.time()

    print(EGES.get_embedding(side_info[:, 0]))
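
Note that --outputEmbedFile is parsed but never written. A minimal, hypothetical way to persist the learned embeddings at the end of the script (assuming the target directory exists) would be:

embeddings = EGES.get_embedding(side_info[:, 0]).numpy()  # [num_nodes, embedding_dim]
np.savetxt(args.outputEmbedFile, embeddings, delimiter=' ')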