Skip to content

Commit

Permalink
Add model architecture diagrams
Browse files Browse the repository at this point in the history
  • Loading branch information
amiangshu committed Feb 26, 2022
1 parent 1486cb7 commit e75868b
Show file tree
Hide file tree
Showing 23 changed files with 20 additions and 9 deletions.
8 changes: 4 additions & 4 deletions DNNModels.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,11 +79,12 @@ def get_coefs(word, *arr): return word, np.asarray(arr, dtype='float32')
class DNNModel:
def __init__(self, X_train=None, Y_train=None, algo="CNN", embedding="fasttext",
max_features=5000, maxlen=500,
embedding_size=300, load_from_file=None):
embedding_size=300, load_from_file=None, plot_file_name =None):
self.max_features = max_features
self.maxlen = maxlen
self.embed_size = embedding_size
self.embedding_type = embedding
self.plot=plot_file_name


session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=4, inter_op_parallelism_threads=4)
Expand Down Expand Up @@ -347,9 +348,8 @@ def _train(self, X_train, Y_train):
X_train_vector = self._prepare_df(X_train)

if model is not None:
#print("Type is:")
#print(type(X_train_vector), type(Y_train))

if self.plot is not None:
tf.keras.utils.plot_model(model, to_file=self.plot, show_shapes=True)
##convert Y_train to np.ndaaray
Y_train = Y_train.values

Expand Down
11 changes: 8 additions & 3 deletions ToxiCR.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,17 +151,22 @@ def load_pretrained_model(self, filename):

def get_model(self, X_train, Y_train, tuning=False):
ALGO = self.ALGO
plot_file_name = "./architecture/" + ALGO + ".png"
if (ALGO == "RF") | (ALGO == "GBT") | (ALGO == "SVM") | (ALGO == "DT") | (ALGO == "LR"):
self.classifier_model = CLEModel(X_train=X_train, Y_train=Y_train, algo=self.ALGO, tuning=tuning)
elif ALGO == "BERT":
from TransformerModel import TransformerModel
self.classifier_model = TransformerModel(X_train=X_train, Y_train=Y_train)
self.classifier_model = TransformerModel(X_train=X_train, Y_train=Y_train, plot_file_name=plot_file_name)
elif (ALGO == "CNN") | (ALGO == "LSTM") | (ALGO == "GRU") | (ALGO == "biLSTM") :
import DNNModels

self.classifier_model = DNNModels.DNNModel(X_train=X_train,
Y_train=Y_train,
algo=ALGO,
algo=ALGO, plot_file_name=plot_file_name,
embedding=self.embedding)
else:
print("Unknown algorithm: "+ALGO)
exit(1)

return self.classifier_model

Expand Down Expand Up @@ -257,7 +262,7 @@ def ten_fold_cross_validation(toxicClassifier, rand_state):
args = parser.parse_args()

print(args)
ALGO = args.algo
ALGO = str(args.algo).upper()
REPEAT = args.repeat
embedding = args.embed
mode = args.mode
Expand Down
6 changes: 5 additions & 1 deletion TransformerModel.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def df_to_dataset(dataframe, batch_size=16):

class TransformerModel:
def __init__(self, X_train=None, Y_train=None,
bert_model_name="bert_en_uncased_L-12_H-768_A-12", load_from_file=None):
bert_model_name="bert_en_uncased_L-12_H-768_A-12", load_from_file=None,
plot_file_name=None):
self.tfhub_handle_encoder = BertLocator.getBERTEncoderURL(bert_model_name)
self.tfhub_handle_preprocess = BertLocator.getPreprocessURL(bert_model_name)

Expand All @@ -48,6 +49,7 @@ def __init__(self, X_train=None, Y_train=None,
self.bert_preprocess_model = hub.KerasLayer(self.tfhub_handle_preprocess)
self.bert_model = hub.KerasLayer(self.tfhub_handle_encoder)
self.epochs = 20
self.plot=plot_file_name

if load_from_file is not None:
self.steps_per_epoch = 19571 # size of our dataset
Expand Down Expand Up @@ -77,6 +79,8 @@ def build_classifier_model(self):
classifier_model.compile(optimizer=optimizer,
loss=loss,
metrics=metrics)
if self.plot is not None:
tf.keras.utils.plot_model(classifier_model, to_file=self.plot, show_shapes=True)
return classifier_model

def get_optimizer(self):
Expand Down
Binary file added model-architecture/BERT.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model-architecture/CNN.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model-architecture/GRU.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model-architecture/LSTM.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added model-architecture/biLSTM.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@ tensorflow-text == 2.5.0
tensorflow-hub>=0.12.0
nltk>=3.5
spacy>=2.3
openpyxl>=3.0.9
openpyxl>=3.0.9
pydot
graphviz

0 comments on commit e75868b

Please sign in to comment.