-
Notifications
You must be signed in to change notification settings - Fork 0
/
prompt2code.py
100 lines (79 loc) · 3.52 KB
/
prompt2code.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
import spacy
import pickle
from tokenize import tokenize, untokenize
from torchtext.data import Field
from colorama import Fore, Style, init
import time
# Mandatory Import (Do not remove) ------> Class which has our Model Structure
from prompt_model import Encoder, EncoderLayer, PositionwiseFeedforwardLayer, MultiHeadAttentionLayer, Decoder, DecoderLayer, Seq2Seq
# Loading the Model and Vocabulary Files
# SECURITY NOTE: torch.load() and pickle.load() both unpickle arbitrary Python
# objects, so these paths must only ever point at files from a trusted source.
try:
    # map_location places every tensor stored in the checkpoint on the CPU.
    model = torch.load('./models/conversational-ai-model-cpu.pt',
                       map_location=torch.device('cpu'))
    print(f"{Fore.LIGHTGREEN_EX}\n> Model fetched successfully{Style.RESET_ALL}")

    with open('./vocabs/source_vocab.pkl', 'rb') as f:
        src_vocab = pickle.load(f)
    print(f"{Fore.LIGHTGREEN_EX}> Source Vocabulary loaded successfully{Style.RESET_ALL}")

    with open('./vocabs/target_vocab.pkl', 'rb') as f:
        trg_vocab = pickle.load(f)
    print(f"{Fore.LIGHTGREEN_EX}> Target Vocabulary loaded successfully{Style.RESET_ALL}")
except Exception as e:
    print(f"{Fore.RED}Error in fetching the model : {e} \n{Style.RESET_ALL}")
    # Re-raise instead of continuing: the code below needs `model`,
    # `src_vocab` and `trg_vocab`, so carrying on would only replace the real
    # cause with a confusing NameError a few lines later.
    raise
# Source (prompt questions) and Target (python codes) Vocabularies
# Both fields split on whitespace, lowercase their tokens and use the same
# <sos>/<eos> sentence markers.
SRC = Field(tokenize=lambda x: x.split(),
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)
TRG = Field(tokenize=lambda x: x.split(),
            init_token='<sos>',
            eos_token='<eos>',
            lower=True)

# Attach the vocabularies that were unpickled above instead of rebuilding them.
SRC.vocab = src_vocab
TRG.vocab = trg_vocab

# Device used for the input tensors.
# The checkpoint is loaded with map_location=cpu and the model is never moved,
# so the inputs must live on the CPU as well. Selecting CUDA whenever it is
# available would put inputs and weights on different devices and fail at the
# first layer on GPU machines.
device = torch.device('cpu')
# Convert the prompt to tokens and tensors, apply the model for code generation
def translate_sentence(sentence, src_field, trg_field, model, device, max_len=50000):
    """Greedily decode ``sentence`` into target-vocabulary tokens.

    Args:
        sentence: Either a raw string (tokenized with spaCy) or an iterable of
            already-split string tokens. Tokens are lowercased in both cases.
        src_field: Object exposing ``init_token``, ``eos_token`` and
            ``vocab.stoi`` for the source side.
        trg_field: Object exposing ``init_token``, ``eos_token``,
            ``vocab.stoi`` and ``vocab.itos`` for the target side.
        model: Object exposing ``eval()``, ``make_src_mask``,
            ``make_trg_mask``, ``encoder`` and ``decoder``.
        device: Device the index tensors are moved to before being fed to
            the model.
        max_len: Upper bound on the number of decoding steps.

    Returns:
        A ``(tokens, attention)`` tuple. ``tokens`` are the generated target
        tokens without the leading init token; the end-of-sequence token is
        included as the last element when the model produced it. ``attention``
        is the second value returned by the final decoder call, or ``None``
        when no decoding step ran (``max_len <= 0``).
    """
    model.eval()

    if isinstance(sentence, str):
        # NOTE(review): the spaCy pipeline is loaded on every call that passes
        # a raw string; callers in a loop may prefer to pre-tokenize.
        nlp = spacy.load('en')
        tokens = [token.text.lower() for token in nlp(sentence)]
    else:
        tokens = [token.lower() for token in sentence]

    # Wrap the prompt in the sentence markers and map tokens to indexes.
    tokens = [src_field.init_token, *tokens, src_field.eos_token]
    src_indexes = [src_field.vocab.stoi[token] for token in tokens]
    # unsqueeze(0) adds a leading dimension of size 1 in front of the sequence.
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)

    target_vocab = trg_field.vocab
    # Loop-invariant: look the end-of-sequence index up once.
    eos_index = target_vocab.stoi[trg_field.eos_token]
    trg_indexes = [target_vocab.stoi[trg_field.init_token]]
    # Defined up front so the return statement is valid even when the loop
    # body never executes.
    attention = None

    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

        for _ in range(max_len):
            # Feed everything generated so far back into the decoder.
            trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
            trg_mask = model.make_trg_mask(trg_tensor)
            output, attention = model.decoder(
                trg_tensor, enc_src, trg_mask, src_mask)

            # Greedy choice: highest-scoring entry at the last position.
            pred_token = output.argmax(2)[:, -1].item()
            trg_indexes.append(pred_token)
            if pred_token == eos_index:
                break

    # Skip index 0, which is the init token seeded above.
    trg_tokens = [target_vocab.itos[i] for i in trg_indexes[1:]]
    return trg_tokens, attention
# Function to generate code from user prompt
def eng_to_python(src):
    """Translate an English prompt into source text using the loaded model.

    Args:
        src: The prompt as a single string.

    Returns:
        The text produced by ``tokenize.untokenize`` from the generated
        tokens, always as ``str``.

    NOTE(review): ``untokenize`` expects token tuples such as
    ``(type, string)``; presumably the target vocabulary stores its entries in
    that form - confirm against the training pipeline.
    """
    # Split on any whitespace run, like the SRC field's own tokenizer, so
    # repeated or trailing spaces do not turn into empty-string tokens.
    words = src.split()
    translation, _ = translate_sentence(words, SRC, TRG, model, device)

    # Only strip the final token when it really is the end-of-sequence marker;
    # decoding that stopped at max_len ends with a genuine token.
    if translation and translation[-1] == TRG.eos_token:
        translation = translation[:-1]

    code = untokenize(translation)
    # untokenize returns bytes only when the stream starts with an ENCODING
    # token; otherwise it already returns str and has no .decode().
    return code.decode('utf-8') if isinstance(code, bytes) else code
if __name__ == "__main__":
    # Interactive entry point: read one prompt, generate code for it and echo
    # the result one character at a time.
    try:
        prompt = input(
            f"\n{Fore.CYAN}>>> Enter the prompt to generate code : {Style.RESET_ALL}")
        answer = eng_to_python(prompt)
        print(f"{Fore.YELLOW}")
        try:
            for char in answer:
                print(char, end="", flush=True)
                # Deliberate pause between characters for the typing effect.
                time.sleep(0.1)
        finally:
            # Always restore the terminal colour, even if printing is cut short.
            print(f"{Style.RESET_ALL}")
    except (KeyboardInterrupt, EOFError):
        # The user cancelled (Ctrl-C / Ctrl-D): leave quietly, this is not an error.
        print(f"\n{Style.RESET_ALL}")
    except Exception as exc:
        print(f"{Fore.RED}Facing some issues, trying again later. {Style.RESET_ALL}")
        # Surface the cause so the failure can actually be diagnosed.
        print(f"{Fore.RED}Details: {exc}{Style.RESET_ALL}")