Commit 3e1f9f1 (parent: 7234ca6)
Showing 6 changed files with 269 additions and 0 deletions.
config.py
@@ -0,0 +1,17 @@
""" | ||
Configurations | ||
""" | ||
from utils import StrMessages | ||
|
||
class config(StrMessages): | ||
MODEL_NAME = "gpt_maximal_00" | ||
INPUT_LENGTH = 128 | ||
DEPTH = 512 | ||
HEADS = 4 | ||
FF_NODES = 1024 | ||
N_LAYERS = 4 | ||
|
||
SHOW_LOSS_HISTORY = True | ||
|
||
CORPUS_URL = "https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt" | ||
|
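Because config inherits from StrMessages (defined in utils.py, the last file of this commit), hyperparameters and CLI messages share a single namespace and are read as plain class attributes, with no instantiation:

from config import config

print(config.DEPTH)          # 512
print(config.MSG_GREETINGS)  # inherited from StrMessages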
inference.py
@@ -0,0 +1,32 @@
""" | ||
Inference | ||
""" | ||
import numpy as np | ||
import tensorflow as tf | ||
import maximal | ||
|
||
from config import config | ||
from models import load_or_build_model | ||
|
||
|
||
def nlg(): | ||
print("Loading model...") | ||
|
||
gpt = load_or_build_model() | ||
|
||
print("Completed.") | ||
|
||
print(config.MSG_GREETINGS) | ||
|
||
while true: | ||
prompt = input("User: ") | ||
|
||
if prompt < config.INPUT_LENGTH: | ||
print(f"Please provide a prompt of {config.INPUT_LENGTH}") | ||
|
||
# If prompt too short send a shakespearean message | ||
print(config.MSG_INPUT_TOO_SHORT.format(config.INPUT_LENGTH)) | ||
continue | ||
|
||
generated_text = generate_text(prompt, config) | ||
print(f"\nShakespeare-GPT: {generated_text}\n") |
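The loop above imports and calls generate_text, but no such function is defined anywhere in this commit. A minimal sketch of what it might look like in model.py, assuming temperature sampling over a char-level vocabulary rebuilt exactly as train.py builds it; the helper _char_maps, the n_chars and temperature parameters, and the signature itself are hypothetical:

import requests
import numpy as np
import tensorflow as tf

from config import config


def _char_maps():
    # Hypothetical helper: rebuild the same char-level vocabulary as train.py
    text = requests.get(config.CORPUS_URL).text
    unique_chars = sorted(set(text))
    char2idx = {char: idx for idx, char in enumerate(unique_chars)}
    idx2char = {idx: char for char, idx in char2idx.items()}
    return char2idx, idx2char


def generate_text(gpt, prompt, config, n_chars=200, temperature=1.0):
    # Sketch only: assumes every prompt character appears in the corpus
    char2idx, idx2char = _char_maps()
    context = [char2idx[c] for c in prompt[-config.INPUT_LENGTH:]]
    generated = []
    for _ in range(n_chars):
        # Roll a window of the last INPUT_LENGTH indices through the model
        x = np.array(context[-config.INPUT_LENGTH:], dtype=np.int32)[np.newaxis, :]
        logits = gpt(x)[0, -1, :] / temperature        # next-char logits
        next_idx = int(tf.random.categorical(logits[tf.newaxis, :], 1)[0, 0])
        context.append(next_idx)
        generated.append(idx2char[next_idx])
    return "".join(generated)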
model.py
@@ -0,0 +1,56 @@
""" | ||
Building GPT architecture | ||
""" | ||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras.layers import Input, Dense | ||
from tensorflow.keras.models import Model | ||
import maximal | ||
from maximal.layers import ( | ||
PositionalEmbedding, GPTLayer | ||
) | ||
|
||
from config import config | ||
|
||
|
||
def build_model(): | ||
""" | ||
Builds a GPT using Maximal and TensorFlow. | ||
Args: / (just needs config params) | ||
Returns: GPT model (tf.keras.models.Model) | ||
""" | ||
# Define nodes of the graph | ||
input_batch = Input(shape=(INPUT_LENGTH,), dtype=tf.int32) | ||
embedding = PositionalEmbedding(INPUT_LENGTH, VOCAB_SIZE, DEPTH) | ||
gpt_layers = [GPTLayer(depth=DEPTH, heads=HEADS, ff_nodes=FF_NODES) for _ in range(N_LAYERS)] | ||
classification_layer = Dense(VOCAB_SIZE) | ||
|
||
# Build the computational graph | ||
x = embedding(input_batch) | ||
|
||
for layer in gpt_layers: | ||
x = layer(x) | ||
|
||
classification = classification_layer(x) | ||
|
||
return Model( | ||
inputs=input_batch, | ||
outputs=classification | ||
) | ||
|
||
|
||
def load_model(): | ||
""" | ||
If a model with a given name already exists | ||
:return: | ||
""" | ||
return gpt | ||
|
||
|
||
def load_or_build_model(): | ||
|
||
# check if the model is | ||
|
||
# | ||
|
||
return gpt |
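A quick sanity check of the graph (a usage sketch, not part of this commit; it relies on the VOCAB_SIZE assumed in config):

import numpy as np

from config import config
from model import build_model

gpt = build_model()
gpt.summary()

# One random batch of token ids -> per-position logits over the vocabulary
x = np.random.randint(0, config.VOCAB_SIZE, size=(2, config.INPUT_LENGTH), dtype=np.int32)
print(gpt(x).shape)  # expected: (2, INPUT_LENGTH, VOCAB_SIZE)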
requirements.txt
@@ -0,0 +1,3 @@
numpy
tensorflow>2.1
maximal>=1.0
train.py
@@ -0,0 +1,129 @@
""" | ||
Training | ||
""" | ||
import requests | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
import matplotlib.pyplot as plt | ||
|
||
from config import config | ||
from model import load_or_build_model | ||
|
||
|
||
# globals | ||
gpt = load_or_build_model() | ||
optimizer = tf.keras.optimizers.Adam(learning_rate=config.LEARNING_RATE) | ||
|
||
|
||
def numerical_encoding(text, char_dict): | ||
""" | ||
First breaks text into a list of chars, then converts each to | ||
its numerical idx (np.array) | ||
""" | ||
chars_list = [ char for char in text ] | ||
chars_list = [ char_dict[char] for char in chars_list ] | ||
chars_list = np.array(chars_list) | ||
return chars_list | ||
|
||
|
||
def get_text_matrix(sequence, len_input): | ||
""" | ||
This generates a matrix containing all the sequences | ||
of length INPUT_LENGTH to be fed into the Network | ||
""" | ||
# create empty matrix | ||
X = np.empty((len(sequence)-len_input, len_input)) | ||
|
||
# fill each row/time window from input sequence | ||
for i in range(X.shape[0]): | ||
X[i,:] = sequence[i : i+len_input] | ||
|
||
return X | ||
|
||
|
||
def process_corpus(): | ||
page = requests.get(config.CORPUS_URL) | ||
text = page.text | ||
|
||
# Store list of unique characters | ||
unique_chars = list(set(text)) | ||
unique_chars.sort() | ||
|
||
# Map every letter in our alphabet to an int | ||
char2idx = {char[1]: char[0] for char in enumerate(unique_chars)} | ||
|
||
# Produce a reverse dictionary to go back from int to str later | ||
idx2char = {v: k for k, v in char2idx.items()} | ||
|
||
encoded_text = numerical_encoding(text, char2idx) | ||
|
||
X = get_text_matrix(encoded_text, INPUT_LENGTH + 1) | ||
|
||
return X | ||
|
||
|
||
@tf.function | ||
def train_on_batch(x, y): | ||
with tf.GradientTape() as tape: | ||
|
||
batch_loss = tf.reduce_sum( | ||
tf.keras.losses.sparse_categorical_crossentropy( | ||
y, gpt(x), | ||
from_logits=True) | ||
) | ||
|
||
gradients = tape.gradient(batch_loss, gpt.trainable_variables) | ||
optimizer.apply_gradients(zip(gradients, gpt.trainable_variables)) | ||
return batch_loss | ||
|
||
|
||
def main(): | ||
X = process_corpus() | ||
|
||
loss_history = [] | ||
|
||
for epoch in range(config.N_EPOCHS): | ||
start = time.time() | ||
|
||
# Reshuffle data at each epoch to randomize mini-batch composition | ||
reshuffle = np.random.choice(X.shape[0], X.shape[0], replace=False) | ||
X = X[reshuffle] | ||
|
||
for iteration in range(X.shape[0] // config.BATCH_SIZE): | ||
|
||
# take new minibatch (with 1 char shift from x to y) | ||
take = iteration * config.BATCH_SIZE | ||
x = X[take:take + config.BATCH_SIZE, :-1] # chars [0:128] | ||
y = X[take:take + config.BATCH_SIZE, 1:] # chars [1:129] | ||
|
||
# training step | ||
current_loss = train_on_batch(x, y) | ||
|
||
# periodically store batch loss into history | ||
if iteration % 100 == 0: | ||
loss_history.append(current_loss) | ||
print(f"\t{iteration}\tLoss: {current_loss}") | ||
|
||
print("{}. \t Loss: {} \t Time: {}ss".format( | ||
epoch + 1, current_loss.numpy(), round(time.time() - start, 2))) | ||
|
||
|
||
if config.SHOW_LOSS_HISTORY: | ||
# Visualize Loss history | ||
plt.figure(figsize=(15, 7)) | ||
plt.plot(loss_history) | ||
plt.title('Loss History') | ||
plt.xlabel('Iterations') | ||
plt.ylabel('Loss (Sparse CCE)') | ||
plt.show() | ||
|
||
# Save model | ||
gpt.save(os.path.join(os.getcwd(), "saved_models", config.MODEL_NAME)) | ||
|
||
return None | ||
|
||
|
||
|
||
if __name__ == "__main__": | ||
main() |
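The one-character shift between x and y is the entire training signal: every position learns to predict the character that follows it. A toy illustration of the same windowing and shift, with the window shrunk to 4 characters for readability:

import numpy as np

seq = np.arange(10)  # stand-in for an encoded corpus
len_input = 4 + 1    # INPUT_LENGTH + 1, as in process_corpus()

# Same windowing as get_text_matrix()
X = np.empty((len(seq) - len_input, len_input), dtype=np.int32)
for i in range(X.shape[0]):
    X[i, :] = seq[i: i + len_input]

x, y = X[:, :-1], X[:, 1:]
print(x[0])  # [0 1 2 3]   model input
print(y[0])  # [1 2 3 4]   targets, shifted one character ahead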
utils.py
@@ -0,0 +1,32 @@
""" | ||
Messages to be displayed in CLI | ||
""" | ||
|
||
|
||
class StrMessages: | ||
MSG_GREETINGS = """ | ||
Hark, good sir or madam, I am none other than Shakespeare-GPT, | ||
A wondrous creation of language and technology. | ||
I, Shakespearean bot, beseech thee, take heed. | ||
With words of lofty prose, I shall adorn | ||
Thy journey through this CLI program's feed. | ||
Thou art most welcome to this humble stage, | ||
Where bytes and lines doth dance in harmony. | ||
Methinks thou seeketh knowledge of this age, | ||
And for thy query, I shall thee gladly see. | ||
Inscribe the word "exit," a concise decree, | ||
Or wield the key combination, Ctrl-C, with glee. | ||
By this act, thou shalt gracefully conclude thy stay, | ||
And from this program's realm, thou may swiftly stray. | ||
""" | ||
|
||
MSG_INPUT_TOO_SHORT = """ | ||
Pray, kind user, if it be not too much to ask, | ||
I beseech thee, extend thy prompt, a greater task. | ||
Yet, one more thing I must humbly request, | ||
A length of {} characters, at its behest. | ||
""" | ||
|
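MSG_INPUT_TOO_SHORT carries a {} placeholder that is filled in at display time; inference.py renders it as:

from config import config

# Inserts the minimum prompt length into the Shakespearean warning
print(config.MSG_INPUT_TOO_SHORT.format(config.INPUT_LENGTH))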