run_conversation.py
# Import pipeline, AutoTokenizer, and torch
from transformers import pipeline, AutoTokenizer
import torch
import os
print("-------------------------------------------")
print("Hugging Face Local Inference Example")
print("Task: Dialogue Simulation (via Text Generation)")
print("Model: microsoft/DialoGPT-medium")
print("Note: Using text-generation pipeline with manual history.")
print("-------------------------------------------")
# Define the model name
model_name = "microsoft/DialoGPT-medium"
# --- Model and Tokenizer Loading ---
print("\nLoading model and tokenizer (may download on first run)...")
try:
    # Load the text-generation pipeline
    generator = pipeline(
        "text-generation",
        model=model_name,
        device=0 if torch.cuda.is_available() else -1
    )
    # Load the tokenizer separately to access eos_token
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    print("Model and tokenizer loaded successfully.")
    if torch.cuda.is_available():
        print(f"Running on GPU: {torch.cuda.get_device_name(0)}")
    else:
        print("Running on CPU.")
except Exception as e:
    print(f"Error loading model or tokenizer: {e}")
    print("Ensure 'transformers' and 'torch' are installed.")
    exit()
# ------------------------------------
# --- Define Conversation Turns ---
user_inputs = [
"Hi, how are you?",
"What's a good plan for a Friday evening in Perth?", # Using context
"That's okay. What do people usually do on Friday evenings?" # Follow-up
]
# Initialize dialogue history string
dialogue_history_string = ""
# --------------------------------
# --- Simulate Conversation ---
print("\n--- Starting Dialogue Simulation ---")
try:
    for i, user_text in enumerate(user_inputs):
        print(f"\nUser >>> {user_text}")
        # Construct the prompt by appending the new user input and the EOS token
        # The EOS token signals the end of a turn for DialoGPT
        prompt = dialogue_history_string + user_text + tokenizer.eos_token
        # Generate a response with the text-generation pipeline
        # max_new_tokens limits the response length
        # pad_token_id suppresses padding warnings during generation
        generated_sequences = generator(
            prompt,
            max_new_tokens=60,  # Adjust max response length as needed
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,  # Add some randomness
            temperature=0.7,
            top_k=50
        )
        # Extract the generated text from the result
        full_generated_text = generated_sequences[0]['generated_text']
        # Keep *only* the newly generated response part.
        # Slicing off the prompt by length is a simple approximation;
        # the model may reformat the prompt slightly during decoding.
        response_text = full_generated_text[len(prompt):].strip()
        # Handle empty responses (e.g., if only the EOS token was generated)
        if not response_text:
            response_text = "(Model generated empty response)"
        print(f"Bot >>> {response_text}")
        print("--------------------")
        # Update the dialogue history string for the next turn
        dialogue_history_string = full_generated_text + tokenizer.eos_token
    print("\n--- Dialogue Finished ---")
except Exception as e:
    print(f"\nError during dialogue generation: {e}")
# ---------------------------
print("\nExample finished.")