From 2d1ab9f73c2a6b5e646075538e54371672e09a87 Mon Sep 17 00:00:00 2001
From: huichuan
Date: Sun, 12 Sep 2021 23:23:13 +0800
Subject: [PATCH] add Seq2seq CNN encoder-decoder example (cnn-lm.py)

---
 EGES/eges.py       |   2 +-
 EGES/run_eges.py   |   2 +-
 Line/TF/Line.ipynb |   2 +-
 Seq2seq/cnn-lm.py  | 130 +++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 133 insertions(+), 3 deletions(-)
 create mode 100644 Seq2seq/cnn-lm.py

diff --git a/EGES/eges.py b/EGES/eges.py
index 44f9852..6954896 100644
--- a/EGES/eges.py
+++ b/EGES/eges.py
@@ -48,6 +48,7 @@ def attention_merge(self):
             cat_embed = tf.nn.embedding_lookup(self.cat_embedding[i], self.batch_features[:, i])
             embed_list.append(cat_embed)
         stack_embed = tf.stack(embed_list, axis=-1)
+        print(stack_embed)
         # attention merge
         alpha_embed = tf.nn.embedding_lookup(self.alpha_embedding, self.batch_features[:, 0])
         alpha_embed_expand = tf.expand_dims(alpha_embed, 1)
@@ -80,7 +81,6 @@ def call(self, side_info, batch_index, batch_labels):
         alpha_embed_expand = tf.expand_dims(alpha_embed, 1)
         alpha_i_sum = tf.reduce_sum(tf.exp(alpha_embed_expand), axis=-1)
         self.merge_emb = tf.reduce_sum(stack_embed * tf.exp(alpha_embed_expand), axis=-1) / alpha_i_sum
-
         return self.make_skipgram_loss(batch_labels)
 
     def get_embedding(self, batch_index):
diff --git a/EGES/run_eges.py b/EGES/run_eges.py
index 223a938..a4ad648 100644
--- a/EGES/run_eges.py
+++ b/EGES/run_eges.py
@@ -82,6 +82,6 @@ def decode_data_pair(line):
                   "{:.4f} sec/batch".format((end - start) / print_every_k_iterations))
             start = time.time()
 
-    
+    print(EGES.get_embedding(side_info[:, 0]))
 
diff --git a/Line/TF/Line.ipynb b/Line/TF/Line.ipynb
index f7c5a40..8ac6d21 100644
--- a/Line/TF/Line.ipynb
+++ b/Line/TF/Line.ipynb
@@ -1005,4 +1005,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 4
-}
+}
\ No newline at end of file
diff --git a/Seq2seq/cnn-lm.py b/Seq2seq/cnn-lm.py
new file mode 100644
index 0000000..2bfe474
--- /dev/null
+++ b/Seq2seq/cnn-lm.py
@@ -0,0 +1,130 @@
+# -*- coding:utf-8 -*-
+# @Time : 2021/9/12 10:01 PM
+# @Author : huichuan LI
+# @File : cnn-lm.py
+# @Software: PyCharm
+
+import tensorflow as tf
+from tensorflow import keras
+import numpy as np
+import utils
+import tensorflow_addons as tfa
+
+
+class CNNTranslation(keras.Model):
+    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, max_pred_len, start_token, end_token):
+        super().__init__()
+        self.units = units
+
+        # encoder
+        self.enc_embeddings = keras.layers.Embedding(
+            input_dim=enc_v_dim, output_dim=emb_dim,  # [enc_n_vocab, emb_dim]
+            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
+        )
+        self.conv2ds = [
+            keras.layers.Conv2D(16, (n, emb_dim), padding="valid", activation=keras.activations.relu)
+            for n in range(2, 5)]
+        self.max_pools = [keras.layers.MaxPool2D((n, 1)) for n in [7, 6, 5]]
+        self.encoder = keras.layers.Dense(units, activation=keras.activations.relu)
+
+        # decoder
+        self.dec_embeddings = keras.layers.Embedding(
+            input_dim=dec_v_dim, output_dim=emb_dim,  # [dec_n_vocab, emb_dim]
+            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
+        )
+        self.decoder_cell = keras.layers.LSTMCell(units=units)
+        decoder_dense = keras.layers.Dense(dec_v_dim)
+        # train decoder
+        self.decoder_train = tfa.seq2seq.BasicDecoder(
+            cell=self.decoder_cell,
+            sampler=tfa.seq2seq.sampler.TrainingSampler(),  # sampler for train
+            output_layer=decoder_dense
+        )
+        # predict decoder
+        self.decoder_eval = tfa.seq2seq.BasicDecoder(
+            cell=self.decoder_cell,
+            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),  # sampler for predict
+            output_layer=decoder_dense
+        )
+
+        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+        self.opt = keras.optimizers.Adam(0.01)
+        self.max_pred_len = max_pred_len
+        self.start_token = start_token
+        self.end_token = end_token
+
+    def encode(self, x):
+        embedded = self.enc_embeddings(x)  # [n, step, emb]
+        o = tf.expand_dims(embedded, axis=3)  # [n, step=8, emb=16, 1]
+        co = [conv2d(o) for conv2d in self.conv2ds]  # [n, 7, 1, 16], [n, 6, 1, 16], [n, 5, 1, 16]
+        co = [self.max_pools[i](co[i]) for i in range(len(co))]  # [n, 1, 1, 16] * 3
+        co = [tf.squeeze(c, axis=[1, 2]) for c in co]  # [n, 16] * 3
+        o = tf.concat(co, axis=1)  # [n, 16*3]
+        h = self.encoder(o)  # [n, units]
+        return [h, h]
+
+    def inference(self, x):
+        s = self.encode(x)
+        done, i, s = self.decoder_eval.initialize(
+            self.dec_embeddings.variables[0],
+            start_tokens=tf.fill([x.shape[0], ], self.start_token),
+            end_token=self.end_token,
+            initial_state=s,
+        )
+        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
+        for l in range(self.max_pred_len):
+            o, s, i, done = self.decoder_eval.step(
+                time=l, inputs=i, state=s, training=False)
+            pred_id[:, l] = o.sample_id
+        return pred_id
+
+    def train_logits(self, x, y, seq_len):
+        s = self.encode(x)
+        dec_in = y[:, :-1]  # ignore <EOS>
+        dec_emb_in = self.dec_embeddings(dec_in)
+        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
+        logits = o.rnn_output
+        return logits
+
+    def step(self, x, y, seq_len):
+        with tf.GradientTape() as tape:
+            logits = self.train_logits(x, y, seq_len)
+            dec_out = y[:, 1:]  # ignore <GO>
+            loss = self.cross_entropy(dec_out, logits)
+        grads = tape.gradient(loss, self.trainable_variables)
+        self.opt.apply_gradients(zip(grads, self.trainable_variables))
+        return loss.numpy()
+
+
+def train():
+    # get and process data
+    data = utils.DateData(4000)
+    print("Chinese time order: yy/mm/dd ", data.date_cn[:3], "\nEnglish time order: dd/M/yyyy ", data.date_en[:3])
+    print("vocabularies: ", data.vocab)
+    print("x index sample: \n{}\n{}".format(data.idx2str(data.x[0]), data.x[0]),
+          "\ny index sample: \n{}\n{}".format(data.idx2str(data.y[0]), data.y[0]))
+
+    model = CNNTranslation(
+        data.num_word, data.num_word, emb_dim=16, units=32,
+        max_pred_len=11, start_token=data.start_token, end_token=data.end_token)
+
+    # training
+    for t in range(1500):
+        bx, by, decoder_len = data.sample(32)
+        loss = model.step(bx, by, decoder_len)
+        if t % 70 == 0:
+            target = data.idx2str(by[0, 1:-1])
+            pred = model.inference(bx[0:1])
+            res = data.idx2str(pred[0])
+            src = data.idx2str(bx[0])
+            print(
+                "t: ", t,
+                "| loss: %.3f" % loss,
+                "| input: ", src,
+                "| target: ", target,
+                "| inference: ", res,
+            )
+
+
+if __name__ == "__main__":
+    train()
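
Reviewer note: the pool sizes [7, 6, 5] hard-coded in CNNTranslation only line up with 8-step inputs, since the three conv windows (heights 2, 3, 4) leave 7, 6 and 5 rows to collapse. A minimal shape check of that encoder path, assuming an illustrative batch of 4 and the emb_dim=16 from encode()'s comments:

    # Shape check for the CNN encoder path (illustrative numbers only:
    # n=4, step=8, emb_dim=16, matching the comments in CNNTranslation.encode).
    import tensorflow as tf
    from tensorflow import keras

    emb_dim = 16
    o = tf.random.normal([4, 8, emb_dim, 1])                 # [n, step, emb, 1]
    conv2ds = [keras.layers.Conv2D(16, (n, emb_dim), padding="valid",
                                   activation="relu") for n in range(2, 5)]
    max_pools = [keras.layers.MaxPool2D((n, 1)) for n in [7, 6, 5]]

    co = [conv(o) for conv in conv2ds]             # heights 8-2+1=7, 8-3+1=6, 8-4+1=5
    co = [max_pools[i](co[i]) for i in range(3)]   # each [4, 1, 1, 16]
    co = [tf.squeeze(c, axis=[1, 2]) for c in co]  # each [4, 16]
    print(tf.concat(co, axis=1).shape)             # (4, 48), fed to the Dense encoder

Any other sequence length would leave the pooled height above 1 and break the tf.squeeze(axis=[1, 2]), so an assert on x.shape[1] == 8 inside encode() may be worth adding.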
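Reviewer note: cnn-lm.py imports utils, which this patch does not include. To smoke-test the file standalone, a hypothetical stand-in for the DateData interface can be saved as utils.py next to it; every name and shape below is inferred from the calls in train() (x: 8 ids for "yy-mm-dd", y: 11 ids for <GO> dd/Mon/yyyy <EOS>), not from the repo's real utils module:

    # Hypothetical stub of utils.DateData -- interface inferred from
    # cnn-lm.py's train(); the repo's real utils module may differ.
    import datetime
    import numpy as np

    MONTHS = ["Jan", "Feb", "Mar", "Apr", "May", "Jun",
              "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"]

    class DateData:
        def __init__(self, n):
            np.random.seed(1)
            ordinals = np.random.randint(652_512, 737_333, n)  # ~1786..2019
            dates = [datetime.date.fromordinal(int(o)) for o in ordinals]
            self.date_cn = [d.strftime("%y-%m-%d") for d in dates]  # 8 chars
            self.date_en = [d.strftime("%d/%b/%Y") for d in dates]  # 11 chars
            self.vocab = (["<PAD>", "<GO>", "<EOS>"]
                          + sorted(set("0123456789-/")) + MONTHS)
            self.v2i = {v: i for i, v in enumerate(self.vocab)}
            self.i2v = {i: v for v, i in self.v2i.items()}
            self.x = np.array([[self.v2i[c] for c in cn] for cn in self.date_cn])
            # <GO> + "dd/" + month abbr as one token + "/yyyy" + <EOS> -> 11 ids
            self.y = np.array([
                [self.v2i["<GO>"]] + [self.v2i[c] for c in en[:3]]
                + [self.v2i[en[3:6]]] + [self.v2i[c] for c in en[6:]]
                + [self.v2i["<EOS>"]]
                for en in self.date_en])
            self.num_word = len(self.vocab)
            self.start_token = self.v2i["<GO>"]
            self.end_token = self.v2i["<EOS>"]

        def sample(self, n):
            idx = np.random.randint(0, len(self.x), n)
            # the decoder consumes y[:, :-1], so its length is one less than y
            seq_len = np.full((n,), self.y.shape[1] - 1, dtype=np.int32)
            return self.x[idx], self.y[idx], seq_len

        def idx2str(self, idx):
            out = []
            for i in idx:
                v = self.i2v[int(i)]
                if v == "<EOS>":
                    break
                out.append(v)
            return "".join(out)

With this stub in place, `python cnn-lm.py` runs end to end, though the printed samples then describe this synthetic data rather than whatever the real utils.DateData produces.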