diff --git a/DNNModels.py b/DNNModels.py index 928462e..d3ad5b8 100644 --- a/DNNModels.py +++ b/DNNModels.py @@ -79,11 +79,12 @@ def get_coefs(word, *arr): return word, np.asarray(arr, dtype='float32') class DNNModel: def __init__(self, X_train=None, Y_train=None, algo="CNN", embedding="fasttext", max_features=5000, maxlen=500, - embedding_size=300, load_from_file=None): + embedding_size=300, load_from_file=None, plot_file_name =None): self.max_features = max_features self.maxlen = maxlen self.embed_size = embedding_size self.embedding_type = embedding + self.plot=plot_file_name session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=4, inter_op_parallelism_threads=4) @@ -347,9 +348,8 @@ def _train(self, X_train, Y_train): X_train_vector = self._prepare_df(X_train) if model is not None: - #print("Type is:") - #print(type(X_train_vector), type(Y_train)) - + if self.plot is not None: + tf.keras.utils.plot_model(model, to_file=self.plot, show_shapes=True) ##convert Y_train to np.ndaaray Y_train = Y_train.values diff --git a/ToxiCR.py b/ToxiCR.py index 93440f8..59b7f62 100644 --- a/ToxiCR.py +++ b/ToxiCR.py @@ -151,17 +151,22 @@ def load_pretrained_model(self, filename): def get_model(self, X_train, Y_train, tuning=False): ALGO = self.ALGO + plot_file_name = "./architecture/" + ALGO + ".png" if (ALGO == "RF") | (ALGO == "GBT") | (ALGO == "SVM") | (ALGO == "DT") | (ALGO == "LR"): self.classifier_model = CLEModel(X_train=X_train, Y_train=Y_train, algo=self.ALGO, tuning=tuning) elif ALGO == "BERT": from TransformerModel import TransformerModel - self.classifier_model = TransformerModel(X_train=X_train, Y_train=Y_train) + self.classifier_model = TransformerModel(X_train=X_train, Y_train=Y_train, plot_file_name=plot_file_name) elif (ALGO == "CNN") | (ALGO == "LSTM") | (ALGO == "GRU") | (ALGO == "biLSTM") : import DNNModels + self.classifier_model = DNNModels.DNNModel(X_train=X_train, Y_train=Y_train, - algo=ALGO, + algo=ALGO, plot_file_name=plot_file_name, embedding=self.embedding) + else: + print("Unknown algorithm: "+ALGO) + exit(1) return self.classifier_model @@ -257,7 +262,7 @@ def ten_fold_cross_validation(toxicClassifier, rand_state): args = parser.parse_args() print(args) - ALGO = args.algo + ALGO = str(args.algo).upper() REPEAT = args.repeat embedding = args.embed mode = args.mode diff --git a/TransformerModel.py b/TransformerModel.py index 5f5a97a..2fa28a1 100644 --- a/TransformerModel.py +++ b/TransformerModel.py @@ -38,7 +38,8 @@ def df_to_dataset(dataframe, batch_size=16): class TransformerModel: def __init__(self, X_train=None, Y_train=None, - bert_model_name="bert_en_uncased_L-12_H-768_A-12", load_from_file=None): + bert_model_name="bert_en_uncased_L-12_H-768_A-12", load_from_file=None, + plot_file_name=None): self.tfhub_handle_encoder = BertLocator.getBERTEncoderURL(bert_model_name) self.tfhub_handle_preprocess = BertLocator.getPreprocessURL(bert_model_name) @@ -48,6 +49,7 @@ def __init__(self, X_train=None, Y_train=None, self.bert_preprocess_model = hub.KerasLayer(self.tfhub_handle_preprocess) self.bert_model = hub.KerasLayer(self.tfhub_handle_encoder) self.epochs = 20 + self.plot=plot_file_name if load_from_file is not None: self.steps_per_epoch = 19571 # size of our dataset @@ -77,6 +79,8 @@ def build_classifier_model(self): classifier_model.compile(optimizer=optimizer, loss=loss, metrics=metrics) + if self.plot is not None: + tf.keras.utils.plot_model(classifier_model, to_file=self.plot, show_shapes=True) return classifier_model def get_optimizer(self): diff --git a/cross-validations/comparisions/Cross-validations-CNN.xlsx b/cross-validations/comparisons/Cross-validations-CNN.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-CNN.xlsx rename to cross-validations/comparisons/Cross-validations-CNN.xlsx diff --git a/cross-validations/comparisions/Cross-validations-DT.xlsx b/cross-validations/comparisons/Cross-validations-DT.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-DT.xlsx rename to cross-validations/comparisons/Cross-validations-DT.xlsx diff --git a/cross-validations/comparisions/Cross-validations-GBC.xlsx b/cross-validations/comparisons/Cross-validations-GBC.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-GBC.xlsx rename to cross-validations/comparisons/Cross-validations-GBC.xlsx diff --git a/cross-validations/comparisions/Cross-validations-GRU.xlsx b/cross-validations/comparisons/Cross-validations-GRU.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-GRU.xlsx rename to cross-validations/comparisons/Cross-validations-GRU.xlsx diff --git a/cross-validations/comparisions/Cross-validations-LR.xlsx b/cross-validations/comparisons/Cross-validations-LR.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-LR.xlsx rename to cross-validations/comparisons/Cross-validations-LR.xlsx diff --git a/cross-validations/comparisions/Cross-validations-LSTM.xlsx b/cross-validations/comparisons/Cross-validations-LSTM.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-LSTM.xlsx rename to cross-validations/comparisons/Cross-validations-LSTM.xlsx diff --git a/cross-validations/comparisions/Cross-validations-RF.xlsx b/cross-validations/comparisons/Cross-validations-RF.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-RF.xlsx rename to cross-validations/comparisons/Cross-validations-RF.xlsx diff --git a/cross-validations/comparisions/Cross-validations-SVM.xlsx b/cross-validations/comparisons/Cross-validations-SVM.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-SVM.xlsx rename to cross-validations/comparisons/Cross-validations-SVM.xlsx diff --git a/cross-validations/comparisions/Cross-validations-biLSTM.xlsx b/cross-validations/comparisons/Cross-validations-biLSTM.xlsx similarity index 100% rename from cross-validations/comparisions/Cross-validations-biLSTM.xlsx rename to cross-validations/comparisons/Cross-validations-biLSTM.xlsx diff --git a/cross-validations/comparisions/cross-validations-BERT.xlsx b/cross-validations/comparisons/cross-validations-BERT.xlsx similarity index 100% rename from cross-validations/comparisions/cross-validations-BERT.xlsx rename to cross-validations/comparisons/cross-validations-BERT.xlsx diff --git a/cross-validations/comparisions/embedding_comparison-CNN.xlsx b/cross-validations/comparisons/embedding_comparison-CNN.xlsx similarity index 100% rename from cross-validations/comparisions/embedding_comparison-CNN.xlsx rename to cross-validations/comparisons/embedding_comparison-CNN.xlsx diff --git a/cross-validations/comparisions/embedding_comparison-GRU.xlsx b/cross-validations/comparisons/embedding_comparison-GRU.xlsx similarity index 100% rename from cross-validations/comparisions/embedding_comparison-GRU.xlsx rename to cross-validations/comparisons/embedding_comparison-GRU.xlsx diff --git a/cross-validations/comparisions/embedding_comparison-LSTM.xlsx b/cross-validations/comparisons/embedding_comparison-LSTM.xlsx similarity index 100% rename from cross-validations/comparisions/embedding_comparison-LSTM.xlsx rename to cross-validations/comparisons/embedding_comparison-LSTM.xlsx diff --git a/cross-validations/comparisions/embedding_comparison-biiLSTM.xlsx b/cross-validations/comparisons/embedding_comparison-biiLSTM.xlsx similarity index 100% rename from cross-validations/comparisions/embedding_comparison-biiLSTM.xlsx rename to cross-validations/comparisons/embedding_comparison-biiLSTM.xlsx diff --git a/model-architecture/BERT.png b/model-architecture/BERT.png new file mode 100644 index 0000000..e68e1a1 Binary files /dev/null and b/model-architecture/BERT.png differ diff --git a/model-architecture/CNN.png b/model-architecture/CNN.png new file mode 100644 index 0000000..1c92f5d Binary files /dev/null and b/model-architecture/CNN.png differ diff --git a/model-architecture/GRU.png b/model-architecture/GRU.png new file mode 100644 index 0000000..663745e Binary files /dev/null and b/model-architecture/GRU.png differ diff --git a/model-architecture/LSTM.png b/model-architecture/LSTM.png new file mode 100644 index 0000000..ea91538 Binary files /dev/null and b/model-architecture/LSTM.png differ diff --git a/model-architecture/biLSTM.png b/model-architecture/biLSTM.png new file mode 100644 index 0000000..fc79e61 Binary files /dev/null and b/model-architecture/biLSTM.png differ diff --git a/requirements.txt b/requirements.txt index d3b93d5..1d26dd0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ tensorflow-text == 2.5.0 tensorflow-hub>=0.12.0 nltk>=3.5 spacy>=2.3 -openpyxl>=3.0.9 \ No newline at end of file +openpyxl>=3.0.9 +pydot +graphviz