
Commit aef6afb

train bert on NER
1 parent 45b178c commit aef6afb

File tree

11 files changed: +79 −33 lines

.gitattributes (+1)

@@ -0,0 +1 @@
+*.arrow filter=lfs diff=lfs merge=lfs -text

.gitignore (+2 −1)

@@ -1,2 +1,3 @@
 .env
-.idea
+.idea
+__pycache__/
6 binary files changed (contents not shown; two of them report size changes of −7.28 MB and −2.03 MB).

main.py (+42)

@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from datasets import load_from_disk
+from torch.utils.data import DataLoader
+from utils.tokenizer import NERTokenizer
+from utils.train import train_ner
+from models.base_model import BertNerd
+
+config = {
+    'PADDING_TOKEN': -100,
+    'LEARNING_RATE': 0.001,
+    'NUM_EPOCHS': 10,
+    'BATCH_SIZE': 16,
+    'RANDOM_SEED': 42,
+    'CHUNK_SIZE': 100,
+    'HIDDEN_SIZE': 768
+}
+
+# Load datasets
+kaznerd_train = load_from_disk('datasets/kaznerd-train.hf')
+kaznerd_test = load_from_disk('datasets/kaznerd-test.hf')
+
+kz_labels_list = kaznerd_train.features["ner_tags"].feature.names
+config['NUM_CLASSES'] = len(kz_labels_list)
+config['DEVICE'] = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+# Initialize tokenizer
+tokenizer = NERTokenizer("bert-base-uncased")
+
+# Tokenize and create dataloaders for Kazakh NER dataset
+kz_tokenized_train = kaznerd_train.map(lambda e: tokenizer.tokenize_and_align_labels(e, tags='ner_tags'), batched=True, batch_size=config['BATCH_SIZE'])
+kz_tokenized_train.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
+
+kz_train_dataloader = DataLoader(kz_tokenized_train, batch_size=config['BATCH_SIZE'])
+
+# Define model, loss function, optimizer
+kaznerd_model = BertNerd(config)
+loss_func = nn.CrossEntropyLoss(ignore_index=config['PADDING_TOKEN'])
+optimizer = torch.optim.Adam(kaznerd_model.get_params(), lr=config['LEARNING_RATE'])
+
+train_ner(model=kaznerd_model, optimizer=optimizer, loss_func=loss_func, train_dataloader=kz_train_dataloader,
+          config=config)
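Note: utils/tokenizer.py is not touched by this commit, so the NERTokenizer used above is only referenced here. Below is a minimal sketch of what its tokenize_and_align_labels method presumably does, following the standard Hugging Face recipe for token classification; the 'tokens' column name, max_length, and the choice to label every subword of a word are assumptions, not the commit's actual code.

# Hypothetical sketch of utils/tokenizer.py (not part of this diff).
from transformers import AutoTokenizer


class NERTokenizer:
    def __init__(self, model_name):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def tokenize_and_align_labels(self, examples, tags='ner_tags'):
        # KazNERD examples are pre-split into words, so tokenize with
        # is_split_into_words=True and pad/truncate to a fixed length.
        tokenized = self.tokenizer(examples['tokens'], is_split_into_words=True,
                                   truncation=True, padding='max_length', max_length=128)

        labels = []
        for i, word_tags in enumerate(examples[tags]):
            word_ids = tokenized.word_ids(batch_index=i)
            aligned = []
            for word_id in word_ids:
                if word_id is None:
                    # Special tokens and padding get -100 (PADDING_TOKEN in config),
                    # which CrossEntropyLoss(ignore_index=-100) skips.
                    aligned.append(-100)
                else:
                    # Copy the word-level tag onto each of its subword tokens.
                    aligned.append(word_tags[word_id])
            labels.append(aligned)

        tokenized['labels'] = labels
        return tokenized

This matches the -100 PADDING_TOKEN and the input_ids/attention_mask/labels columns that main.py expects after set_format.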

models/base_model.py (+22 −25)

@@ -1,21 +1,23 @@
 import torch
-from transformers import AutoModelForMaskedLM
+from transformers import BertModel
 
 class BertNerd(torch.nn.Module):
     """
     MBert-based model for performing NER tasks w/ and w/o
     soft prompts on Kazakh and Turkish languages.
     """
 
-    def __init__(self, config, device, freeze=True):
+    def __init__(self, config, freeze=True):
         super(BertNerd, self).__init__()
-        self.mbert = AutoModelForMaskedLM("google-bert/bert-base-multilingual-cased")
-        self.linear = torch.nn.Linear(config.hidden_size, config.num_classes)
-        self.device = device
+        self.device = config['DEVICE']
+        self.mbert = BertModel.from_pretrained("google-bert/bert-base-multilingual-uncased").to(self.device)
+        self.linear = torch.nn.Linear(config['HIDDEN_SIZE'], config['NUM_CLASSES'])
 
         if freeze:
             self.freeze_params()
 
+        print("\tModel initilized.")
+
     def forward(self, input_seq, attention_mask):
         """
         Define the model's forward pass.
@@ -24,37 +26,32 @@ def forward(self, input_seq, attention_mask):
         :param attention_mask: attention mask
         :return: predicted logits
         """
-
         input_seq = self.mbert(input_seq, attention_mask).last_hidden_state.to(self.device)
         logits = self.linear(input_seq)
 
         return logits
 
-    def get_loss(self, loss_fn, logits, labels, ignore_index=None):
+    def freeze_params(self):
         """
-        Get loss for the forward pass of the current batch.
+        Only train the soft prompts, don't train any model parameters.
 
-        :param loss_fn: e.g. nn.CrossEntropyLoss
-        :param logits: predicted labels
-        :param labels: actual labels from the dataset
-        :param ignore_index: padding index to ignore
-        :return: loss per batch
+        :return: void
         """
 
-        loss_func = loss_fn(ignore_index=ignore_index)
-
-        # ToDo
-        # Logits/labels should probably be flattened, so we get the right dimension
-
-        return loss_func(logits, labels).detach().item()
-
+        for param in self.mbert.parameters():
+            param.requires_grad = False
 
-    def freeze_params(self):
+    def get_params(self):
         """
-        Only train the soft prompts, don't train any model parameters.
+        Return tunable parameters of the model.
 
-        :return: void
+        :return: list of tunable params
         """
 
-        for param in self.bert.parameters():
-            param.requires_grad = False
+        params = []
+
+        for param in self.parameters():
+            if param.requires_grad:
+                params.append(param)
+
+        return params
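With freeze=True, the mBERT encoder's weights are frozen and get_params() returns only the parameters that still require gradients, so the Adam optimizer in main.py updates just the linear classification head. An illustrative shape sanity check, not part of the commit; the config values below are example assumptions:

# Quick check that the model maps token ids to per-token logits of shape
# (batch, seq_len, NUM_CLASSES) and that only the head stays trainable.
import torch

config = {'DEVICE': 'cpu', 'HIDDEN_SIZE': 768, 'NUM_CLASSES': 25}  # example values
model = BertNerd(config)  # downloads the pretrained mBERT weights on first run

input_ids = torch.randint(0, 1000, (2, 16))   # (batch, seq_len) of dummy token ids
attention_mask = torch.ones_like(input_ids)

logits = model(input_ids, attention_mask)
print(logits.shape)             # torch.Size([2, 16, 25])
print(len(model.get_params()))  # 2: the linear layer's weight and bias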

utils/train.py (+12 −7)

@@ -1,18 +1,22 @@
 import torch
-from metrics import get_accuracy
+from .metrics import get_accuracy
 
 
-def train_ner(model, train_dataloader, optimizer, config):
+def train_ner(model, train_dataloader, loss_func, optimizer, config):
     """
     Define the training loop for NER.
     :param model: corresponding model class
-    :param train_loader: train data
+    :param train_dataloader: train data
+    :param loss_func: loss function
+    :param optimizer: optimizer
     :param config: config file with hyperparameters
     :return: model, metrics
     """
+    print("\tTraining started.")
+
     accuracies = []
 
-    for epoch in range(config.epochs):
+    for epoch in range(config['NUM_EPOCHS']):
         loss_per_epoch = 0
         correct = 0
         total = 0
@@ -21,14 +25,15 @@ def train_ner(model, train_dataloader, optimizer, config):
         model.train()
 
         for batch in train_dataloader:
-            inputs, attention_mask, labels = batch["input_ids"].to(config.device), batch["attention_mask"].to(config.device), batch["labels"].to(config.device)
+            inputs, attention_mask, labels = (batch["input_ids"].to(config['DEVICE']), batch["attention_mask"].to(config['DEVICE']),
+                                              batch["labels"].to(config['DEVICE']))
 
             # Make prediction
             logits = model(inputs, attention_mask)
 
             # Calculate loss
-            batch_loss = model.get_loss(logits, labels)
-            loss_per_epoch += batch_loss
+            batch_loss = loss_func(logits.flatten(end_dim=1), labels.flatten(end_dim=1))
+            loss_per_epoch += batch_loss.detach().item()
 
             # Get ids corresponding to the most probably NER tags
             tag_ids = torch.max(logits, dim=2).indices
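The get_accuracy imported here lives in utils/metrics.py, which this commit does not touch. A plausible sketch under the assumption that it compares predicted tag ids against labels while skipping the -100 padding positions; the signature and return values are assumptions:

# Hypothetical sketch of utils/metrics.py (not part of this diff).
import torch


def get_accuracy(tag_ids, labels, padding_token=-100):
    # Count only real tokens so special tokens and padding don't inflate accuracy.
    mask = labels != padding_token
    correct = (tag_ids[mask] == labels[mask]).sum().item()
    total = mask.sum().item()
    return correct, total

Returning the raw correct/total counts would let the loop above accumulate them across batches before computing the per-epoch accuracy.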
