diff --git a/nlg.py b/nlg.py index 924a295..3a19231 100644 --- a/nlg.py +++ b/nlg.py @@ -81,6 +81,7 @@ def main(): Starts an infinite loop that can be broken only via Ctrl+C or by typing "exit" as prompt. """ + # TODO: this is an overkill - change that, pickle dict locally _, char2idx = process_corpus() idx2char = {v: k for k, v in char2idx.items()} @@ -98,17 +99,15 @@ def main(): while True: prompt = input("\nUser:\n") - # TODO: CHECK FOR EXIT BEFORE CHECK FOR INPUT LEN + if prompt.strip() == "exit": + print(config.MSG_FAREWELL) + quit() if len(prompt) < config.INPUT_LENGTH: - print(f"Please provide a prompt of {config.INPUT_LENGTH}") - + print(f"\nPlease provide a prompt of {config.INPUT_LENGTH}") # If prompt too short send a shakespearean message print(config.MSG_INPUT_TOO_SHORT.format(config.INPUT_LENGTH)) continue - elif prompt == "exit": - print(config.MSG_FAREWELL) - quit() else: generated_text = generate_text(gpt, prompt, char2idx, idx2char) print(f"\nShakespeare-GPT:\n{generated_text}\n")