diff --git a/delft/utilities/Embeddings.py b/delft/utilities/Embeddings.py index 633330f3..47988f11 100644 --- a/delft/utilities/Embeddings.py +++ b/delft/utilities/Embeddings.py @@ -87,7 +87,7 @@ def __init__(self, name, if use_ELMo: self.make_ELMo() self.embed_size = ELMo_embed_size + self.embed_size - description = self._get_description('elmo-'+self.lang) + description = self.get_description('elmo-'+self.lang) self.env_ELMo = None if description and description["cache-training"] and self.use_cache: self.embedding_ELMo_cache = os.path.join(description["path-cache"], "cache") @@ -106,7 +106,7 @@ def __init__(self, name, #self.session.run(tf.global_variables_initializer()) self.make_BERT() self.embed_size = BERT_embed_size + self.embed_size - description = self._get_description('bert-base-'+self.lang) + description = self.get_description('bert-base-'+self.lang) self.env_BERT = None if description and description["cache-training"] and self.use_cache: self.embedding_BERT_cache = os.path.join(description["path-cache"], "cache") @@ -131,7 +131,7 @@ def make_embeddings_simple_in_memory(self, name="fasttext-crawl"): nbWords = 0 print('loading embeddings...') begin = True - description = self._get_description(name) + description = self.get_description(name) if description is not None: embeddings_path = description["path"] self.lang = description["lang"] @@ -168,7 +168,7 @@ def make_embeddings_simple_in_memory(self, name="fasttext-crawl"): def make_embeddings_lmdb(self, name="fasttext-crawl"): print('\nCompiling embeddings... (this is done only one time per embeddings at first usage)') - description = self._get_description(name) + description = self.get_description(name) if description is None: print('\nNo description found in embeddings registry for embeddings', name) @@ -195,7 +195,7 @@ def load_embeddings_from_file(self, embeddings_path): nb_lines = 0 # read number of lines first - embedding_file = _open_embedding_file(embeddings_path) + embedding_file = open_embedding_file(embeddings_path) if embedding_file is None: print("Error: could not open embeddings file", embeddings_path) return @@ -204,7 +204,7 @@ def load_embeddings_from_file(self, embeddings_path): nb_lines += 1 embedding_file.close() - embedding_file = _open_embedding_file(embeddings_path) + embedding_file = open_embedding_file(embeddings_path) #with open(embeddings_path, encoding='utf8') as f: for line in tqdm(embedding_file, total=nb_lines): line = line.decode() @@ -266,7 +266,7 @@ def clean_downloads(self): print('Failed to delete %s. Reason: %s' % (file_path, e)) def make_embeddings_simple(self, name="fasttext-crawl"): - description = self._get_description(name) + description = self.get_description(name) if description is not None: self.extension = description["format"] @@ -294,7 +294,7 @@ def make_embeddings_simple(self, name="fasttext-crawl"): envFilePath = os.path.join(self.embedding_lmdb_path, name) load_db = True if os.path.isdir(envFilePath): - description = self._get_description(name) + description = self.get_description(name) if description is not None: self.lang = description["lang"] @@ -333,7 +333,7 @@ def make_embeddings_simple(self, name="fasttext-crawl"): def make_ELMo(self): # Location of pretrained BiLM for the specified language # TBD check if ELMo language resources are present - description = self._get_description('elmo-'+self.lang) + description = self.get_description('elmo-'+self.lang) if description is not None: self.lang = description["lang"] vocab_file = description["path-vocab"] @@ -358,7 +358,7 @@ def make_ELMo(self): def make_BERT(self): # Location of BERT model - description = self._get_description('bert-base-'+self.lang) + description = self.get_description('bert-base-'+self.lang) if description is not None: self.lang = description["lang"] config_file = description["path-config"] @@ -570,7 +570,7 @@ def get_sentence_vector_with_BERT(self, token_list): return concatenated_squeezed_result - def _get_description(self, name): + def get_description(self, name): for emb in self.registry["embeddings"]: if emb["name"] == name: return emb @@ -824,7 +824,7 @@ def _serialize_pickle(a): def _deserialize_pickle(serialized): return pickle.loads(serialized) -def _open_embedding_file(embeddings_path): +def open_embedding_file(embeddings_path): # embeddings can be uncompressed or compressed with gzip or zip if embeddings_path.endswith(".gz"): embedding_file = gzip.open(embeddings_path, mode="rt")