-
Notifications
You must be signed in to change notification settings - Fork 3
/
preprocessing.py
49 lines (31 loc) · 1.14 KB
/
preprocessing.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import transformers as ts
from datasets import Dataset
from datasets import load_dataset, load_from_disk
import numpy as np
import numpy.core.defchararray as nchar
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from transformers.adapters import AdapterConfig
import math
# Download (or reuse a cached copy of) the PubMed dataset from the Hugging
# Face hub; cache files live under dataset/ so repeated runs skip the download.
ds = load_dataset('pubmed', cache_dir="dataset/")
# Print the DatasetDict summary (splits and row counts) as a quick sanity check.
print(ds)
def getTextFromSample(sample):
    """Build one training text from a PubMed MedlineCitation record.

    Joins the article title (square brackets removed, whitespace trimmed)
    and the abstract text with a single space, then strips the result so
    an empty title or abstract leaves no stray spaces.
    """
    article = sample["Article"]
    raw_title = article["ArticleTitle"]
    cleaned_title = raw_title.replace("[", "").replace("]", "").strip()
    abstract_text = article["Abstract"]["AbstractText"].strip()
    return f"{cleaned_title} {abstract_text}".strip()
# Load the tokenizer matching the intended model checkpoint; downloaded from
# the hub on first use. The tokenized dataset below is only valid for models
# that share this vocabulary.
tokenizer = ts.AutoTokenizer.from_pretrained("bert-base-cased")
# Echo the tokenizer config so the run log records which vocab was used.
print(tokenizer)
def mappingFunction(dataset):
    """Tokenize one batch of PubMed records for `Dataset.map(batched=True)`.

    Parameters
    ----------
    dataset : a batched mapping of column name -> list of values; only the
        "MedlineCitation" column is read here.

    Returns
    -------
    The tokenizer's batch encoding (input_ids, attention_mask, and
    special_tokens_mask), with each example truncated to 256 tokens.
    """
    # Comprehension replaces the manual append loop (same texts, same order).
    texts = [getTextFromSample(sample) for sample in dataset["MedlineCitation"]]
    # return_special_tokens_mask=True so a downstream MLM data collator can
    # avoid masking [CLS]/[SEP] tokens.
    return tokenizer(texts, truncation=True, max_length=256, return_special_tokens_mask=True)
# Tokenize the whole train split; batched=True passes dicts of lists to
# mappingFunction, which is much faster than per-example mapping.
ds["train"] = ds["train"].map(mappingFunction, batched=True)
# Persist the tokenized dataset to disk, then re-load and print it as a
# sanity check that the saved copy is readable.
datasetPath = "tokenizedDatasets/pubmed-256/"
ds.save_to_disk(datasetPath)
print(load_from_disk(datasetPath))