EmbeddingProcessor.py
import torch
import numpy as np
from torch.utils.data import DataLoader
from scipy.spatial.distance import cosine
import os


class EmbeddingProcessor:
    """
    A class to handle mean pooling, batch embedding computation, similarity calculation,
    and saving/loading embeddings.

    Attributes:
        tokenizer (AutoTokenizer): The tokenizer for the model.
        model (AutoModel): The pre-trained model for generating embeddings.
    """

    def __init__(self, tokenizer, model):
        """
        Initializes the EmbeddingProcessor with a tokenizer and model.

        Args:
            tokenizer (AutoTokenizer): The tokenizer used to tokenize input texts.
            model (AutoModel): The pre-trained model used to generate embeddings.
        """
        self.tokenizer = tokenizer
        self.model = model

    def mean_pooling(self, token_embeddings, attention_mask):
        """
        Perform mean pooling on token embeddings, weighted by the attention mask.

        Args:
            token_embeddings (Tensor): The token embeddings from the model.
            attention_mask (Tensor): The attention mask for the input.

        Returns:
            Tensor: The pooled embeddings.
        """
        token_embeddings = token_embeddings * attention_mask.unsqueeze(-1)
        return token_embeddings.sum(dim=1) / attention_mask.sum(dim=1).unsqueeze(-1)

    def batch_compute_embeddings(self, col_vals, batch_size=32):
        """
        Compute embeddings for a list of values in batches.

        Args:
            col_vals (list): List of text values to compute embeddings for.
            batch_size (int): The size of each batch for processing.

        Returns:
            dict: A dictionary mapping text to its computed embedding.
        """
        embeddings = {}
        dataloader = DataLoader(col_vals, batch_size=batch_size, shuffle=False)
        for batch in dataloader:
            inputs = self.tokenizer(batch, return_tensors="pt", truncation=True, padding=True, max_length=64)
            with torch.no_grad():
                outputs = self.model(**inputs)
            batch_embeddings = self.mean_pooling(outputs.last_hidden_state, inputs['attention_mask'])
            for i, title in enumerate(batch):
                embeddings[title] = batch_embeddings[i].cpu().numpy()
        return embeddings

    def get_similar_items(self, ind_title, cached_embeddings):
        """
        Compute similarity between a given industry title and cached embeddings.

        Args:
            ind_title (str): The industry title to compare against cached embeddings.
            cached_embeddings (dict): A dictionary of precomputed embeddings.

        Returns:
            list: The top 10 most similar items from the cached embeddings.
        """
        ind_inputs = self.tokenizer(ind_title, return_tensors="pt", truncation=True, padding=True, max_length=64)
        with torch.no_grad():
            ind_outputs = self.model(**ind_inputs)
        ind_embedding = self.mean_pooling(ind_outputs.last_hidden_state, ind_inputs['attention_mask'])
        ind_vector = ind_embedding.squeeze().cpu().numpy()
        similarities = [
            (title, 1 - cosine(ind_vector, occ_vector))
            for title, occ_vector in cached_embeddings.items()
        ]
        return sorted(similarities, key=lambda x: x[1], reverse=True)[:10]

    def save_embeddings(self, embeddings, filename="embeddings_bert.npy"):
        """
        Save computed embeddings to a file.

        Args:
            embeddings (dict): The dictionary of embeddings to save.
            filename (str): The filename where embeddings will be saved.
        """
        np.save(filename, embeddings)

    def load_embeddings(self, string_list, batch_size=64, filename="embeddings_bert.npy"):
        """
        Load embeddings from a file, or compute and save them if they don't exist.

        Args:
            string_list (list): The list of text to compute embeddings for.
            batch_size (int): The size of each batch for processing.
            filename (str): The filename to load or save embeddings.

        Returns:
            dict: The embeddings loaded or newly computed.
        """
        if os.path.exists(filename):
            embeddings = np.load(filename, allow_pickle=True).item()
        else:
            embeddings = self.batch_compute_embeddings(string_list, batch_size)
            self.save_embeddings(embeddings, filename)
        return embeddings
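

# Usage sketch (not part of the original module): shows how the class is typically
# wired up with Hugging Face AutoTokenizer/AutoModel, then used to cache embeddings
# and rank similar titles. The checkpoint name, the example titles, and the query
# string below are illustrative assumptions, not values taken from this repository.
if __name__ == "__main__":
    from transformers import AutoTokenizer, AutoModel

    model_name = "bert-base-uncased"  # assumed checkpoint; any encoder-style model works
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name)

    processor = EmbeddingProcessor(tokenizer, model)

    # Hypothetical example data; the processor runs on CPU unless the model and
    # inputs are moved to another device beforehand.
    titles = ["Software Engineer", "Data Scientist", "Registered Nurse"]

    # Computes embeddings on the first run and reuses the cached .npy file afterwards.
    cached = processor.load_embeddings(titles, batch_size=2, filename="embeddings_bert.npy")

    # Rank the cached titles by cosine similarity to a query title.
    for title, score in processor.get_similar_items("Machine Learning Engineer", cached):
        print(f"{title}: {score:.3f}")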