
refactor embedding methods a bit for docker preload
kermitt2 committed Dec 23, 2020
1 parent d88804a commit 85c9b78
Showing 2 changed files with 91 additions and 85 deletions.
164 changes: 89 additions & 75 deletions delft/utilities/Embeddings.py
@@ -56,7 +56,14 @@
 
 class Embeddings(object):
 
-    def __init__(self, name, path='./embedding-registry.json', lang='en', extension='vec', use_ELMo=False, use_BERT=False, use_cache=True):
+    def __init__(self, name,
+                 path='./embedding-registry.json',
+                 lang='en',
+                 extension='vec',
+                 use_ELMo=False,
+                 use_BERT=False,
+                 use_cache=True,
+                 load=True):
         self.name = name
         self.embed_size = 0
         self.static_embed_size = 0
@@ -69,7 +76,8 @@ def __init__(self, name, path='./embedding-registry.json', lang='en', extension=
         if self.registry is not None:
             self.embedding_lmdb_path = self.registry["embedding-lmdb-path"]
         self.env = None
-        self.make_embeddings_simple(name)
+        if load:
+            self.make_embeddings_simple(name)
         self.static_embed_size = self.embed_size
         self.bilm = None
 
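The new load flag makes the loading step optional at construction time, which is the hook for Docker preloading: an image build step can trigger the one-time embedding compilation, while runtime code can instantiate the class cheaply. A minimal usage sketch, assuming a 'glove-840B' entry exists in embedding-registry.json:

    # construct without loading anything (new in this commit)
    emb = Embeddings('glove-840B', load=False)
    # later, when vectors are actually needed, do what load=True does in __init__:
    emb.make_embeddings_simple('glove-840B')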
@@ -159,9 +167,7 @@ def make_embeddings_simple_in_memory(self, name="fasttext-crawl"):
         print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
 
     def make_embeddings_lmdb(self, name="fasttext-crawl"):
-        nbWords = 0
         print('\nCompiling embeddings... (this is done only one time per embeddings at first usage)')
-        begin = True
         description = self._get_description(name)
 
         if description is None:
@@ -175,81 +181,89 @@ def make_embeddings_lmdb(self, name="fasttext-crawl"):
             print('\nCould not locate a usable resource for embeddings', name)
             return
 
-        txn = self.env.begin(write=True)
-        # batch_size = 1024
-        i = 0
-        nb_lines = 0
+        self.load_embeddings_from_file(embeddings_path)
 
-        # read number of lines first
-        embedding_file = _open_embedding_file(embeddings_path)
-        if embedding_file is None:
-            print("Error: could not open embeddings file", embeddings_path)
-            return
+        # cleaning possible downloaded embeddings
+        self.clean_downloads()
 
-        for line in embedding_file:
-            nb_lines += 1
-        embedding_file.close()
-
-        embedding_file = _open_embedding_file(embeddings_path)
-        #with open(embeddings_path, encoding='utf8') as f:
-        for line in tqdm(embedding_file, total=nb_lines):
-            line = line.decode()
-            line = line.split(' ')
-            if begin:
-                begin = False
-                nb_words, embed_size = _fetch_header_if_available(line)
-
-                if nb_words > 0 and embed_size > 0:
-                    nbWords = nb_words
-                    self.embed_size = embed_size
-                    continue
-
-            word = line[0]
-            try:
-                if line[len(line)-1] == '\n':
-                    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
-                else:
-                    vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
-
-                #vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
-            except:
-                print(len(line))
-                print(line[1:len(line)])
-            #else:
-            #    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
-            if self.embed_size == 0:
-                self.embed_size = len(vector)
-
-            if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
-                txn.put(word.encode(encoding='UTF-8'), _serialize_pickle(vector))
-                #txn.put(word.encode(encoding='UTF-8'), _serialize_byteio(vector))
-                i += 1
-
-            # commit batch
-            # if i % batch_size == 0:
-            #     txn.commit()
-            #     txn = self.env.begin(write=True)
-
-        embedding_file.close()
-
-        #if i % batch_size != 0:
-        txn.commit()
-        if nbWords == 0:
-            nbWords = i
-        self.vocab_size = nbWords
-        print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
+    def load_embeddings_from_file(self, embeddings_path):
+        begin = True
+        nbWords = 0
+        txn = self.env.begin(write=True)
+        # batch_size = 1024
+        i = 0
+        nb_lines = 0
+
+        # read number of lines first
+        embedding_file = _open_embedding_file(embeddings_path)
+        if embedding_file is None:
+            print("Error: could not open embeddings file", embeddings_path)
+            return
+
-        # cleaning possible downloaded embeddings
-        for filename in os.listdir(self.registry['embedding-download-path']):
-            file_path = os.path.join(self.registry['embedding-download-path'], filename)
-            try:
-                if os.path.isfile(file_path) or os.path.islink(file_path):
-                    os.unlink(file_path)
-                elif os.path.isdir(file_path):
-                    shutil.rmtree(file_path)
-            except Exception as e:
-                print('Failed to delete %s. Reason: %s' % (file_path, e))
+        for line in embedding_file:
+            nb_lines += 1
+        embedding_file.close()
+
+        embedding_file = _open_embedding_file(embeddings_path)
+        #with open(embeddings_path, encoding='utf8') as f:
+        for line in tqdm(embedding_file, total=nb_lines):
+            line = line.decode()
+            line = line.split(' ')
+            if begin:
+                begin = False
+                nb_words, embed_size = _fetch_header_if_available(line)
+
+                if nb_words > 0 and embed_size > 0:
+                    nbWords = nb_words
+                    self.embed_size = embed_size
+                    continue
+
+            word = line[0]
+            try:
+                if line[len(line)-1] == '\n':
+                    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
+                else:
+                    vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
+
+                #vector = np.array([float(val) for val in line[1:len(line)]], dtype='float32')
+            except:
+                print(len(line))
+                print(line[1:len(line)])
+            #else:
+            #    vector = np.array([float(val) for val in line[1:len(line)-1]], dtype='float32')
+            if self.embed_size == 0:
+                self.embed_size = len(vector)
+
+            if len(word.encode(encoding='UTF-8')) < self.env.max_key_size():
+                txn.put(word.encode(encoding='UTF-8'), _serialize_pickle(vector))
+                #txn.put(word.encode(encoding='UTF-8'), _serialize_byteio(vector))
+                i += 1
+
+            # commit batch
+            # if i % batch_size == 0:
+            #     txn.commit()
+            #     txn = self.env.begin(write=True)
+
+        embedding_file.close()
+
+        #if i % batch_size != 0:
+        txn.commit()
+        if nbWords == 0:
+            nbWords = i
+        self.vocab_size = nbWords
+        print('embeddings loaded for', nbWords, "words and", self.embed_size, "dimensions")
+
+    def clean_downloads(self):
+        # cleaning possible downloaded embeddings
+        for filename in os.listdir(self.registry['embedding-download-path']):
+            file_path = os.path.join(self.registry['embedding-download-path'], filename)
+            try:
+                if os.path.isfile(file_path) or os.path.islink(file_path):
+                    os.unlink(file_path)
+                elif os.path.isdir(file_path):
+                    shutil.rmtree(file_path)
+            except Exception as e:
+                print('Failed to delete %s. Reason: %s' % (file_path, e))
 
     def make_embeddings_simple(self, name="fasttext-crawl"):
         description = self._get_description(name)
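For context on the begin/header handling in load_embeddings_from_file: fastText-style .vec files open with a 'nb_words embed_size' header line, while GloVe-style files start directly with a word vector, so the first line has to be probed. _fetch_header_if_available is defined outside this diff; a plausible sketch of its contract, inferred from the call site above:

    def _fetch_header_if_available(line):
        # 'line' is the first line of the embeddings file, already split on ' ';
        # return (nb_words, embed_size) for a header line, (-1, -1) otherwise
        if len(line) == 2:
            try:
                return int(line[0]), int(line[1])
            except ValueError:
                pass
        return -1, -1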
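The LMDB store written by load_embeddings_from_file holds one record per token: the key is the UTF-8 bytes of the word and the value is the serialized vector. Assuming _serialize_pickle is pickle-based, as its name suggests, and that the store lives under the registry's embedding-lmdb-path, reading a vector back could look like this sketch:

    import lmdb
    import pickle

    env = lmdb.open('./data/db/glove-840B', readonly=True)  # hypothetical path
    with env.begin() as txn:
        raw = txn.get('the'.encode(encoding='UTF-8'))
    if raw is not None:
        vector = pickle.loads(raw)  # numpy float32 array of length embed_size
        print(vector.shape, vector.dtype)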
12 changes: 2 additions & 10 deletions grobidTagger.py
@@ -206,8 +206,7 @@ class Tasks:
 
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(
-        description = "Trainer for GROBID models")
+    parser = argparse.ArgumentParser(description = "Trainer for GROBID models")
 
     actions = [Tasks.TRAIN, Tasks.TRAIN_EVAL, Tasks.EVAL, Tasks.TAG]
     architectures = [BidLSTM_CRF.name, BidLSTM_CNN.name, BidLSTM_CNN_CRF.name, BidGRU_CRF.name, BidLSTM_CRF_CASING.name,
@@ -220,8 +219,7 @@ class Tasks:
                         "cross validation.")
     parser.add_argument("--architecture", default='BidLSTM_CRF', choices=architectures,
                         help="Type of model architecture to be used, one of "+str(architectures))
-    parser.add_argument(
-        "--embedding", default='glove-840B',
+    parser.add_argument("--embedding", default='glove-840B',
         help=(
             "The desired pre-trained word embeddings using their descriptions in the file"
             " embedding-registry.json."
@@ -233,7 +231,6 @@ class Tasks:
     parser.add_argument("--output", help="Directory where to save a trained model.")
     parser.add_argument("--input", help="Grobid data file to be used for training (train action), for trainng and "
                         "evaluation (train_eval action) or just for evaluation (eval action).")
-    # parser.add_argument("--ignore-features", default=False, action="store_true", help="Ignore layout features")
     parser.add_argument(
         "--feature-indices",
         type=parse_number_ranges,
@@ -243,14 +240,9 @@ class Tasks:
     args = parser.parse_args()
 
     model = args.model
-    #if not model in models:
-    #    print('invalid model, should be one of', models)
-
     action = args.action
-
     use_ELMo = args.use_ELMo
     architecture = args.architecture
-
     output = args.output
     input_path = args.input
     embeddings_name = args.embedding
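The training entry point itself is unchanged by this cleanup; a typical invocation, where the 'citation' model name is only illustrative, would be:

    python3 grobidTagger.py citation train --architecture BidLSTM_CRF --embedding glove-840B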
