-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.py
32 lines (29 loc) · 1.21 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
import torch
class Carolina:
    """Text classifier wrapping a saved sequence-classification checkpoint.

    Per the original author's note, the base model used is bloom-560m. The
    model and its tokenizer are both loaded from the same ``model_path`` and
    the model is moved to CUDA when available, otherwise CPU.

    Attributes set by ``__init__``:
        model: The loaded ``AutoModelForSequenceClassification`` instance.
        tokenize: Callable that turns raw text into model-ready tensors on
            the same device as the model.
    """

    def __init__(self, model_path: str = "./", *, max_length: int = 256) -> None:
        """Load the model and tokenizer and prepare the tokenize callable.

        Args:
            model_path: Path of the saved model weights; the tokenizer is
                loaded from the same location.
            max_length: Token length every input is truncated and padded to.
                Defaults to 256.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        def tokenize(text):
            # Truncate/pad to a fixed length and return PyTorch tensors that
            # live on the same device as the model.
            encoded = tokenizer(
                text,
                truncation=True,
                padding="max_length",
                max_length=max_length,
                return_tensors="pt",
            )
            return encoded.to(device)

        self.tokenize = tokenize
        self.model.to(device)

    def predict(self, text: str) -> int:
        """Classify a single text and return the predicted class index.

        Args:
            text: The text to classify.

        Returns:
            Index of the highest-scoring logit; the original author documents
            this as 0 or 1.
        """
        encoded = self.tokenize(text)
        # Inference only: skip gradient tracking.
        with torch.no_grad():
            output = self.model(**encoded)
        # .item() requires a single-element result, i.e. one input text.
        return int(torch.argmax(output.logits, dim=1).item())
if __name__ == "__main__":
    # Manual smoke run: load the checkpoint from the default path, classify
    # one sample sentence and print a human-readable label for the result.
    classifier = Carolina()
    sample = "Hi I am text classifier, will help you in deleteing housing ads message"
    # A truthy prediction (non-zero class) is reported as "Ad", otherwise "Safe".
    label = "Ad" if classifier.predict(sample) else "Safe"
    print(label)