From 85c9b789a76b117ef4ec147ae6a07bc4ee5d65ed Mon Sep 17 00:00:00 2001
From: lopez
Date: Wed, 23 Dec 2020 12:01:20 +0100
Subject: [PATCH] refactor a bit embedding methods for docker preload

---
 delft/utilities/Embeddings.py | 164 ++++++++++++++++++----------------
 grobidTagger.py               |  12 +--
 2 files changed, 91 insertions(+), 85 deletions(-)

diff --git a/delft/utilities/Embeddings.py b/delft/utilities/Embeddings.py
index af73c17f..633330f3 100644
--- a/delft/utilities/Embeddings.py
+++ b/delft/utilities/Embeddings.py
@@ -56,7 +56,14 @@ class Embeddings(object):
-    def __init__(self, name, path='./embedding-registry.json', lang='en', extension='vec', use_ELMo=False, use_BERT=False, use_cache=True):
+    def __init__(self, name,
+                 path='./embedding-registry.json',
+                 lang='en',
+                 extension='vec',
+                 use_ELMo=False,
+                 use_BERT=False,
+                 use_cache=True,
+                 load=True):
         self.name = name
         self.embed_size = 0
         self.static_embed_size = 0
@@ -69,7 +76,8 @@ def __init__(self, name, path='./embedding-registry.json', lang='en', extension=
         if self.registry is not None:
             self.embedding_lmdb_path = self.registry["embedding-lmdb-path"]
         self.env = None
-        self.make_embeddings_simple(name)
+        if load:
+            self.make_embeddings_simple(name)
         self.static_embed_size = self.embed_size
         self.bilm = None
@@ -159,9 +167,7 @@ def make_embeddings_simple_in_memory(self, name="fasttext-crawl"):
         print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
 
     def make_embeddings_lmdb(self, name="fasttext-crawl"):
-        nbWords = 0
         print('\nCompiling embeddings... (this is done only one time per embeddings at first usage)')
-        begin = True
         description = self._get_description(name)
         if description is None:
@@ -175,81 +181,89 @@ def make_embeddings_lmdb(self, name="fasttext-crawl"):
             print('\nCould not locate a usable resource for embeddings', name)
             return
 
-        txn = self.env.begin(write=True)
-        # batch_size = 1024
-        i = 0
-        nb_lines = 0
+        self.load_embeddings_from_file(embeddings_path)
 
-        # read number of lines first
-        embedding_file = _open_embedding_file(embeddings_path)
-        if embedding_file is None:
-            print("Error: could not open embeddings file", embeddings_path)
-            return
+        # cleaning possible downloaded embeddings
+        self.clean_downloads()
 
-        for line in embedding_file:
-            nb_lines += 1
-        embedding_file.close()
-
-        embedding_file = _open_embedding_file(embeddings_path)
-        #with open(embeddings_path, encoding='utf8') as f:
-        for line in tqdm(embedding_file, total=nb_lines):
-            line = line.decode()
-            line = line.split(' ')
-            if begin:
-                begin = False
-                nb_words, embed_size = _fetch_header_if_available(line)
-
-                if nb_words > 0 and embed_size > 0:
-                    nbWords = nb_words
-                    self.embed_size = embed_size
-                    continue
-
-            word = line[0]
-            try:
-                if line[len(line)-1] == '\n':
-                    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
-                else:
-                    vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
-
-                #vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
-            except:
-                print(len(line))
-                print(line[1:len(line)])
-            #else:
-            #    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
-            if self.embed_size == 0:
-                self.embed_size = len(vector)
-
-            if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
-                txn.put(word.encode(encoding='UTF-8'), _serialize_pickle(vector))
-                #txn.put(word.encode(encoding='UTF-8'), _serialize_byteio(vector))
-                i += 1
-
-            # commit batch
-            # if i % batch_size == 0:
-            #     txn.commit()
-            #     txn = self.env.begin(write=True)
-
-        embedding_file.close()
-
-        #if i % batch_size != 0:
-        txn.commit()
-        if nbWords == 0:
-            nbWords = i
-        self.vocab_size = nbWords
-        print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
+    def load_embeddings_from_file(self, embeddings_path):
+        begin = True
+        nbWords = 0
+        txn = self.env.begin(write=True)
+        # batch_size = 1024
+        i = 0
+        nb_lines = 0
+
+        # read number of lines first
+        embedding_file = _open_embedding_file(embeddings_path)
+        if embedding_file is None:
+            print("Error: could not open embeddings file", embeddings_path)
+            return
 
-        # cleaning possible downloaded embeddings
-        for filename in os.listdir(self.registry['embedding-download-path']):
-            file_path = os.path.join(self.registry['embedding-download-path'], filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                print('Failed to delete %s. Reason: %s' % (file_path, e))
+        for line in embedding_file:
+            nb_lines += 1
+        embedding_file.close()
+
+        embedding_file = _open_embedding_file(embeddings_path)
+        #with open(embeddings_path, encoding='utf8') as f:
+        for line in tqdm(embedding_file, total=nb_lines):
+            line = line.decode()
+            line = line.split(' ')
+            if begin:
+                begin = False
+                nb_words, embed_size = _fetch_header_if_available(line)
+
+                if nb_words > 0 and embed_size > 0:
+                    nbWords = nb_words
+                    self.embed_size = embed_size
+                    continue
+            word = line[0]
+            try:
+                if line[len(line)-1] == '\n':
+                    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
+                else:
+                    vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
+
+                #vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
+            except:
+                print(len(line))
+                print(line[1:len(line)])
+            #else:
+            #    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
+            if self.embed_size == 0:
+                self.embed_size = len(vector)
+
+            if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
+                txn.put(word.encode(encoding='UTF-8'), _serialize_pickle(vector))
+                #txn.put(word.encode(encoding='UTF-8'), _serialize_byteio(vector))
+                i += 1
+
+            # commit batch
+            # if i % batch_size == 0:
+            #     txn.commit()
+            #     txn = self.env.begin(write=True)
+
+        embedding_file.close()
+
+        #if i % batch_size != 0:
+        txn.commit()
+        if nbWords == 0:
+            nbWords = i
+        self.vocab_size = nbWords
+        print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
+
+    def clean_downloads(self):
+        # cleaning possible downloaded embeddings
+        for filename in os.listdir(self.registry['embedding-download-path']):
+            file_path = os.path.join(self.registry['embedding-download-path'], filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print('Failed to delete %s. Reason: %s' % (file_path, e))
 
     def make_embeddings_simple(self, name="fasttext-crawl"):
         description = self._get_description(name)
diff --git a/grobidTagger.py b/grobidTagger.py
index 9dea22fc..74c95c43 100644
--- a/grobidTagger.py
+++ b/grobidTagger.py
@@ -206,8 +206,7 @@ class Tasks:
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description = "Trainer for GROBID models")
+    parser = argparse.ArgumentParser(description = "Trainer for GROBID models")
 
     actions = [Tasks.TRAIN, Tasks.TRAIN_EVAL, Tasks.EVAL, Tasks.TAG]
     architectures = [BidLSTM_CRF.name, BidLSTM_CNN.name, BidLSTM_CNN_CRF.name, BidGRU_CRF.name, BidLSTM_CRF_CASING.name,
@@ -220,8 +219,7 @@ class Tasks:
                              "cross validation.")
     parser.add_argument("--architecture", default='BidLSTM_CRF', choices=architectures,
                         help="Type of model architecture to be used, one of "+str(architectures))
-    parser.add_argument(
-        "--embedding", default='glove-840B',
+    parser.add_argument("--embedding", default='glove-840B',
         help=(
             "The desired pre-trained word embeddings using their descriptions in the file"
             " embedding-registry.json."
@@ -233,7 +231,6 @@ class Tasks:
     parser.add_argument("--output", help="Directory where to save a trained model.")
     parser.add_argument("--input", help="Grobid data file to be used for training (train action), for trainng and "
                                         "evaluation (train_eval action) or just for evaluation (eval action).")
-    # parser.add_argument("--ignore-features", default=False, action="store_true", help="Ignore layout features")
     parser.add_argument(
         "--feature-indices",
         type=parse_number_ranges,
@@ -243,14 +240,9 @@ class Tasks:
     args = parser.parse_args()
 
     model = args.model
-    #if not model in models:
-    #    print('invalid model, should be one of', models)
-
     action = args.action
-
     use_ELMo = args.use_ELMo
     architecture = args.architecture
-
     output = args.output
     input_path = args.input
     embeddings_name = args.embedding
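
A minimal sketch of how the refactored interface above might be driven at Docker image build time to preload embeddings into the LMDB store. The script name, the registry entry 'glove-840B' and the exact call pattern are illustrative assumptions, not part of the patch; only the Embeddings constructor, its new load flag, and clean_downloads() come from the change itself.

    # preload_embeddings.py -- hypothetical build-time helper, not part of this patch
    from delft.utilities.Embeddings import Embeddings

    if __name__ == "__main__":
        # With the default load=True, constructing the object triggers
        # make_embeddings_simple(), which compiles the LMDB database on first use
        # (when an embedding-lmdb-path is configured) and, with this patch,
        # removes any downloaded archives afterwards via clean_downloads().
        Embeddings("glove-840B")  # assumed registry entry from embedding-registry.json

        # The new load=False flag creates the object without reading any vectors,
        # e.g. to inspect registry settings or defer loading until later.
        lazy = Embeddings("glove-840B", load=False)
        print(lazy.embedding_lmdb_path)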