"""
Classify speech acts from SBERT sentence embeddings.
"""

from . import base
import stanza.models.common.doc as doc
import speechact.annotate as anno
import speechact.corpus as corp
import speechact.preprocess as pre
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as tdat
import sentence_transformers as stf
import collections as col
from typing import Callable, Generator

NetworkFactory = Callable[[int, int], nn.Module]
"""
A factory that builds a classification network from (input_size, output_size).
"""

SPEECH_ACTS = [
anno.SpeechActLabels.ASSERTION,
anno.SpeechActLabels.QUESTION,
anno.SpeechActLabels.DIRECTIVE,
anno.SpeechActLabels.EXPRESSIVE
]
"""
The speech act labels to classify. Note that the HYPOTHESIS is not included.
"""
class CorpusDataset(tdat.Dataset):
"""
A Pytorch compatible dataset for a speech act labeled Corpus.
All sentences are loaded into memory. However, it is only the sent_id, text, and speech_act
that is stored.
"""

    def __init__(self, corpus: corp.Corpus) -> None:
super().__init__()
# Load sentences.
self.sentences = [anno.Sentence(s.text, str(s.sent_id), s.speech_act) for s in corpus.sentences()]
# Count class frequencies.
self.class_frequencies = col.Counter()
for sentence in self.sentences:
self.class_frequencies[sentence.label] += 1

    def __len__(self) -> int:
return len(self.sentences)

    def __getitem__(self, index: int) -> tuple[str, int]:
sentence = self.sentences[index]
speech_act_class_index = SPEECH_ACTS.index(sentence.label)
return sentence.text, speech_act_class_index

    def get_class_frequency(self, class_index: int) -> int:
speech_act = SPEECH_ACTS[class_index]
return self.class_frequencies[speech_act]


class DocumentDataset(tdat.Dataset):
"""
A Pytorch compatible dataset for a speech act labeled Stanza Document.
"""

    def __init__(self, document: doc.Document) -> None:
super().__init__()
self.document = document

    def __len__(self) -> int:
return len(self.document.sentences)

    def __getitem__(self, index: int) -> tuple[str, int]:
sentence = self.document.sentences[index]
speech_act_class_index = SPEECH_ACTS.index(sentence.speech_act)
return sentence.text, speech_act_class_index

    def batched(self, batch_size: int) -> Generator[list[doc.Sentence], None, None]:
        """Yield the document's sentences in batches of at most batch_size."""
        batch = []
        for sentence in self.document.sentences:
            batch.append(sentence)
            if len(batch) == batch_size:
                yield batch
                batch = []
        # Yield the final, possibly smaller, batch.
        if batch:
            yield batch


def linear_perceptron(input_size: int, output_size: int) -> nn.Module:
    """A single linear layer that outputs raw class scores."""
    return nn.Linear(input_size, output_size)


def softmax_perceptron(input_size: int, output_size: int) -> nn.Module:
    """A linear layer followed by a softmax over the class dimension."""
return nn.Sequential(
nn.Linear(input_size, output_size),
nn.Softmax(dim=1)
)


def sigmoid_hidden_layer(input_size: int, output_size: int) -> nn.Module:
    """A network with one sigmoid hidden layer and a softmax output."""
    hidden_size = 64
return nn.Sequential(
nn.Linear(input_size, hidden_size),
nn.Sigmoid(),
nn.Linear(hidden_size, output_size),
nn.Softmax(dim=1)
)
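
# Illustrative sketch: any callable matching the NetworkFactory signature can be
# passed to EmbeddingClassifier below. The factory and layer sizes here are
# examples, not part of the original module:
#
#     def relu_hidden_layer(input_size: int, output_size: int) -> nn.Module:
#         return nn.Sequential(
#             nn.Linear(input_size, 128),
#             nn.ReLU(),
#             nn.Linear(128, output_size),
#         )
#
#     classifier = EmbeddingClassifier(device='cpu', network_factory=relu_hidden_layer)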


class EmbeddingClassifier(base.Classifier):
    """
    Classifies sentences based on their embeddings. The sentence embeddings are
    computed with a Swedish SBERT model.

    Args:
        device: the name of the device to run the models on.
        network_factory: a callable that creates the classification network.
    """

    def __init__(self, device='mps', network_factory: NetworkFactory|None = None) -> None:
        # 'mps' is the Apple Silicon GPU backend; use e.g. 'cuda' or 'cpu' elsewhere.
super().__init__()
self.device = device
# Load embedding model.
self.emb_model = stf.SentenceTransformer('KBLab/sentence-bert-swedish-cased', device=device)
# Create the neural network.
input_size: int = self.emb_model.get_sentence_embedding_dimension() # type: ignore
output_size = len(SPEECH_ACTS)
# Create the network from the factory, or use default linear perceptron.
        if network_factory is not None:
self.cls_model = network_factory(input_size, output_size)
else:
self.cls_model = linear_perceptron(input_size, output_size)
# Run on device.
self.cls_model = self.cls_model.to(device)

    def classify_document(self, document: doc.Document):
        """Classify the speech act of every sentence in the document, in batches."""
doc_dataset = DocumentDataset(document)
self.cls_model.eval()
with torch.no_grad():
for batch in doc_dataset.batched(32):
texts = [sent.text for sent in batch]
embeddings = self.emb_model.encode(texts, convert_to_numpy=False, # type: ignore
convert_to_tensor=True)
outputs = self.cls_model(embeddings)
# Assign outputs to sentences.
for output, sentence in zip(outputs, batch):
class_index = torch.argmax(output)
speech_act = SPEECH_ACTS[class_index]
sentence.speech_act = speech_act # type: ignore

    def classify_sentence(self, sentence: doc.Sentence):
speech_act = self.get_speech_act_for(sentence)
sentence.speech_act = speech_act # type: ignore

    def get_speech_act_for(self, sentence: doc.Sentence|str) -> anno.SpeechActLabels:
        """
        Classify the speech act of the sentence. This only returns the speech act;
        it does not assign it to the 'speech_act' property of the sentence instance.
        """
# Get sentence text from input.
if isinstance(sentence, doc.Sentence):
            assert sentence.text is not None, f'sentence.text == None for {sentence.sent_id}'
text = sentence.text
else:
text = sentence
self.cls_model.eval()
# Create embedding and classify.
embedding = self.emb_model.encode(text, convert_to_numpy=False)
with torch.no_grad():
            output = self.cls_model(embedding)  # type: ignore
class_index = torch.argmax(output)
return SPEECH_ACTS[class_index]
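
    # Usage sketch (assumes a previously trained and saved network; the file
    # name is hypothetical):
    #
    #     classifier = EmbeddingClassifier(device='cpu')
    #     classifier.load('embedding_classifier.pt')
    #     label = classifier.get_speech_act_for('Stäng dörren!')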

    def train(self, data: CorpusDataset, batch_size: int, num_epochs=10,
              save_each_epoch: None|str = None, use_class_weights=False,
              loss_history: list[float]|None = None, dev_loss_history: list[float]|None = None,
              dev_data: CorpusDataset|None = None,
              callback_each_batch=-1,
              batch_callback: Callable[[int], None]|None = None):
        """
        Train the classifier on labeled sentences from a corpus.

        The sentences are embedded with the SBERT model, and the classification
        network is fit to the labels with Adam and cross-entropy loss.
        """
import tqdm # For progress bar.
optimizer = optim.Adam(self.cls_model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss().to(self.device)
        # Weight each class inversely to its training frequency.
if use_class_weights:
class_weights = [1.0 / data.get_class_frequency(i) for i in range(len(SPEECH_ACTS))]
criterion.weight = torch.tensor(class_weights, dtype=torch.float32).to(self.device)
# Handle training data.
train_loader = tdat.DataLoader(data, batch_size=batch_size, shuffle=True)
# Handle dev data.
        if dev_data is not None:
dev_loader = tdat.DataLoader(dev_data, batch_size=batch_size, shuffle=True)
# Train the network.
batch_count = 0
for epoch in range(num_epochs):
self.cls_model.train()
running_loss = 0.0
            for inputs, labels in tqdm.tqdm(train_loader, desc=f'Training: epoch {epoch+1}/{num_epochs}', unit='batch'):
labels = labels.to(self.device)
optimizer.zero_grad()
# Do forward pass.
embeddings = self.emb_model.encode(inputs,
convert_to_numpy=False,
convert_to_tensor=True)
outputs = self.cls_model(embeddings)
# Compute loss and backpropagate.
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
# Do callback.
                if (batch_callback is not None and callback_each_batch != -1 and
                        batch_count % callback_each_batch == 0):
                    batch_callback(batch_count)
batch_count += 1
# Save model.
            if save_each_epoch is not None:
self.save(save_each_epoch)
            # Calculate the average loss for the epoch.
epoch_loss = running_loss / len(train_loader)
print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss}')
# Save the loss to history.
            if loss_history is not None:
loss_history.append(epoch_loss)
            # Compute loss on dev data.
            if dev_data is not None:
                running_dev_loss = 0.0
                self.cls_model.eval()
                with torch.no_grad():  # No gradients are needed for evaluation.
                    for inputs, labels in tqdm.tqdm(dev_loader, desc=f'Eval on dev data: epoch {epoch+1}/{num_epochs}', unit='batch'):
                        labels = labels.to(self.device)
                        embeddings = self.emb_model.encode(inputs,
                                                           convert_to_numpy=False,
                                                           convert_to_tensor=True)
                        outputs = self.cls_model(embeddings)
                        loss = criterion(outputs, labels)
                        running_dev_loss += loss.item()
                # Calculate the average dev loss for the epoch.
dev_epoch_loss = running_dev_loss / len(dev_loader)
print(f'Epoch {epoch+1}/{num_epochs}, Dev loss: {dev_epoch_loss}')
# Save the dev loss to history.
                if dev_loss_history is not None:
dev_loss_history.append(dev_epoch_loss)
print('Training complete')

    def save(self, file_name: str):
        """
        Save the model to a file. This only saves the classification network, not
        the embedding model.
        """
print(f'Saving model to "{file_name}"')
torch.save(self.cls_model.state_dict(), file_name)

    def load(self, file_name: str):
"""
Load the model from a file. This only loads the classification network and not the
embedding model.
"""
print(f'Loading model from "{file_name}"')
self.cls_model.load_state_dict(torch.load(file_name, map_location=self.device))
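

# Minimal end-to-end sketch (illustrative only; the import path, corpus files,
# and model file name are assumptions, not part of this module):
#
#     import speechact.corpus as corp
#     from speechact.classifier.embedding import CorpusDataset, EmbeddingClassifier
#
#     train_data = CorpusDataset(corp.Corpus('data/train.conllu'))
#     dev_data = CorpusDataset(corp.Corpus('data/dev.conllu'))
#
#     classifier = EmbeddingClassifier(device='cpu')
#     classifier.train(train_data, batch_size=32, num_epochs=10,
#                      use_class_weights=True, dev_data=dev_data)
#     classifier.save('embedding_classifier.pt')
#     print(classifier.get_speech_act_for('Vad heter du?'))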