generated from StartBootstrap/startbootstrap-agency
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathnlu_model.py
33 lines (25 loc) · 1.1 KB
/
nlu_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import logging
import pprint
from rasa_nlu.training_data import load_data
from rasa_nlu import config
from rasa_nlu.model import Trainer
from rasa_nlu.model import Interpreter
from rasa_nlu.evaluate import run_evaluation
# Log file that both train_nlu and run_nlu point logging.basicConfig at.
logfile = "nlu_model.log"


def train_nlu(data_path, configs, model_path, *, project_name='current',
              fixed_model_name='nlu', eval_data_path=None, evaluate=True):
    """Train a Rasa NLU model, persist it, and optionally evaluate it.

    Args:
        data_path: Path to the NLU training data, passed to ``load_data``.
        configs: Path to the pipeline configuration file, passed to
            ``config.load``.
        model_path: Base directory the trained model is persisted under.
        project_name: Project name passed to ``Trainer.persist``. Defaults
            to ``'current'``, the value this function always used before.
        fixed_model_name: Model name passed to ``Trainer.persist``. Defaults
            to ``'nlu'``, the value this function always used before.
        eval_data_path: Data file to evaluate the persisted model against.
            Defaults to ``data_path``, i.e. the training data itself. NOTE:
            scores computed on the training data are in-sample and will
            overstate real accuracy; pass a held-out file when possible.
        evaluate: When False, skip the ``run_evaluation`` step entirely.

    Returns:
        The model directory returned by ``Trainer.persist``. Callers can
        pass it to ``run_nlu`` or ``Interpreter.load``.
    """
    # basicConfig is a no-op once the root logger has handlers, so calling
    # it from every entry point is harmless.
    logging.basicConfig(filename=logfile, level=logging.DEBUG)

    training_data = load_data(data_path)
    trainer = Trainer(config.load(configs))
    trainer.train(training_data)
    model_directory = trainer.persist(
        model_path,
        project_name=project_name,
        fixed_model_name=fixed_model_name,
    )

    if evaluate:
        if eval_data_path is None:
            eval_data_path = data_path
        run_evaluation(eval_data_path, model_directory)

    return model_directory


def run_nlu(nlu_path, *, queries=None):
    """Load a persisted NLU model, parse some utterances, and print the results.

    Args:
        nlu_path: Directory of a persisted model, passed to
            ``Interpreter.load``.
        queries: Iterable of utterances to parse. When omitted, the three
            built-in sample questions below are used.

    Returns:
        A list of ``interpreter.parse`` results, one per query, in input
        order. Each result is also pretty-printed as it is produced.
    """
    # basicConfig is a no-op once the root logger has handlers, so calling
    # it from every entry point is harmless.
    logging.basicConfig(filename=logfile, level=logging.DEBUG)

    if queries is None:
        # NOTE(review): the second sample lacks the trailing '?' the others
        # have; kept verbatim, so confirm whether that is intentional.
        queries = (
            "What do I do when I'm sad?",
            "What do I do when I'm happy",
            "What do I do when I'm stressed?",
        )

    interpreter = Interpreter.load(nlu_path)
    results = []
    for query in queries:
        result = interpreter.parse(query)
        pprint.pprint(result)
        results.append(result)
    return results


if __name__ == '__main__':
    # Command-line entry point: optionally retrain, then run the sample
    # queries. Replaces toggling training via a commented-out call.
    import argparse

    parser = argparse.ArgumentParser(
        description='Train and/or query the Rasa NLU model.')
    parser.add_argument(
        '--train',
        action='store_true',
        help='retrain from ./data/nlu.md with nlu_config.yml before querying',
    )
    args = parser.parse_args()

    if args.train:
        # Persists under ./models with project 'current' and model 'nlu',
        # presumably ./models/current/nlu, which is the path queried below.
        train_nlu('./data/nlu.md', 'nlu_config.yml', './models')
    run_nlu('./models/current/nlu')
No commit comments for this range