-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathprompt_level.py
140 lines (117 loc) · 5.71 KB
/
prompt_level.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import pandas as pd
import numpy as np
import argparse
import itertools
import matplotlib.dates as dates
from ast import literal_eval
import seaborn as sns
from scipy import stats
from scipy.optimize import curve_fit
from collections import Counter, OrderedDict
import matplotlib.pyplot as plt
from utils import *
def prompt_stats(dataset_name, data_path):
    """Print prompt-level summary counts for one dataset CSV.

    Args:
        dataset_name: Dataset identifier. For 'midjourney' and 'diffusiondb'
            the number of distinct ``author_id`` values is printed as well.
        data_path: Path to a CSV file with a ``prompt`` column (plus an
            ``author_id`` column for the two datasets named above).
    """
    table = pd.read_csv(data_path)
    total_prompts = len(table)
    distinct_prompts = len(set(table['prompt'].tolist()))
    print('Number of prompts: {}. Number of unique prompts: {}'.format(total_prompts, distinct_prompts))

    # Only these datasets carry per-user information.
    if dataset_name not in ['midjourney', 'diffusiondb']:
        return
    distinct_users = len(set(table['author_id'].tolist()))
    print('Number of users: {}'.format(distinct_users))
def plot_prompt_freq(dataset_name, data_path, save_root, num=100):
    """Plot the prompt rank-frequency curve with a power-law fit and save the top prompts.

    Prompts are rebuilt by joining each token list returned by
    ``read_tokens_from_file`` with single spaces, then counted and ranked by
    descending frequency (ties keep first-seen order). The raw curve and a
    ``power_law`` fit (both from the project ``utils`` module) are drawn on
    log-log axes into a new current matplotlib figure. The figure is left open
    so the caller can show or save it.

    Args:
        dataset_name: Prefix of the output CSV file name.
        data_path: File passed to ``read_tokens_from_file``.
        save_root: Output directory; it is created if it does not exist.
        num: Number of most frequent prompts written to the CSV.

    Returns:
        Path of the written ``<dataset_name>_top_prompts.csv`` file.

    Raises:
        ValueError: If ``data_path`` yields no prompts (nothing to fit).
    """
    tokens_list = read_tokens_from_file(data_path)
    prompts = [' '.join(tokens) for tokens in tokens_list]
    if not prompts:
        raise ValueError('No prompts found in {}'.format(data_path))

    # most_common() sorts by count descending with a stable sort, so ties keep
    # their first-seen order.
    ranked = Counter(prompts).most_common()
    ranked_prompts = [prompt for prompt, _ in ranked]
    freq = [count for _, count in ranked]
    ranks = np.arange(1, len(freq) + 1)

    plt.figure(figsize=(10, 5))
    plt.loglog(ranks, freq, label='Raw data', linewidth=3)
    # Both fit parameters are bounded below by 1e-3; upper bounds are 1e20 and 50.
    popt, _ = curve_fit(power_law, ranks, freq, p0=[1, 1], bounds=[[1e-3, 1e-3], [1e20, 50]])
    plt.plot(ranks, power_law(ranks, *popt), label='Zipf\'s law fit', linestyle="dashed",
             color="tab:blue", alpha=0.5, linewidth=3)
    plt.legend()
    plt.grid(alpha=0.5)
    plt.xlabel('Prompt ranked by frequency (log-scale)')
    plt.ylabel('Prompt frequency (log-scale)')

    os.makedirs(save_root, exist_ok=True)
    out_path = os.path.join(save_root, '{}_top_prompts.csv'.format(dataset_name))
    pd.DataFrame({'prompt': ranked_prompts, 'freq': freq}).head(num).to_csv(out_path, index=False)
    return out_path
# Separator between the date part and the time-of-day part of raw timestamps,
# per supported dataset.
_TIMESTAMP_DATE_TIME_SEP = {'midjourney': 'T', 'diffusiondb': ' '}


def plot_timestamp_hist_24h(dataset_name, data_path):
    """Plot the proportion of prompts submitted in each hour of the day.

    The hour is the text after the dataset's date/time separator and before the
    first ':' of each ``timestamp`` value (e.g. '2022-07-13T01:23:45' -> '01'
    for midjourney). Every hour is pinned to the dummy date 2000-01-01 so all
    prompts fall on one 24-hour axis. The plot is drawn into a new current
    matplotlib figure, which is left open for the caller.

    NOTE(review): hours with no prompts are absent from the counts, so the line
    interpolates across them rather than dropping to zero.

    Args:
        dataset_name: 'midjourney' or 'diffusiondb'.
        data_path: Path to a CSV file with a ``timestamp`` column.

    Raises:
        ValueError: If ``dataset_name`` has no known timestamp layout.
    """
    if dataset_name not in _TIMESTAMP_DATE_TIME_SEP:
        raise ValueError('Timestamp histogram is only supported for {}; got {!r}.'.format(
            sorted(_TIMESTAMP_DATE_TIME_SEP), dataset_name))
    sep = _TIMESTAMP_DATE_TIME_SEP[dataset_name]

    df = pd.read_csv(data_path)
    hour_keys = ['2000-01-01T' + t.split(':')[0].split(sep)[1] for t in df['timestamp'].tolist()]
    # Keys are unique, so sorting the (key, count) pairs orders them chronologically.
    hour_counts = sorted(Counter(hour_keys).items())
    counts = np.array([count for _, count in hour_counts])
    freq_density = counts / counts.sum()
    timestamps = np.array([key for key, _ in hour_counts], dtype='datetime64')

    _, ax = plt.subplots(figsize=(15, 6))
    ax.plot(timestamps, freq_density, label=dataset_name)
    ax.xaxis.set_major_locator(dates.HourLocator())
    ax.xaxis.set_major_formatter(dates.DateFormatter('%H:%M'))
    # Hide the outermost hour ticks at both ends of the axis.
    xticks = ax.xaxis.get_major_ticks()
    xticks[0].set_visible(False)
    xticks[-1].set_visible(False)
    ax.set_xlabel('Time (hour)')
    ax.set_ylabel('Proportion of prompts')
    plt.setp(ax.xaxis.get_majorticklabels(), rotation=90)
    ax.legend()
    ax.grid(alpha=0.5)
def plot_rating_hist(sac_path):
    """Plot the normalized histogram of every individual rating in a SAC CSV.

    Each ``rating`` cell is parsed with ``ast.literal_eval`` into an iterable
    of iterables and flattened, so every single rating counts once. Bars are
    centred on the integer ratings 1-10. The plot is drawn into a new
    matplotlib figure, which is left open for the caller.

    Args:
        sac_path: Path to a CSV file with a ``rating`` column.
    """
    sac_file = pd.read_csv(sac_path)
    all_ratings = [
        rating
        for cell in sac_file['rating'].tolist()
        for group in literal_eval(cell)
        for rating in group
    ]

    fig = plt.figure(figsize=(10, 3))
    ax = fig.add_subplot(1, 1, 1)
    ax.margins(x=0)
    # Bin edges at k-0.3 with 0.6-wide bars put each bar's centre on integer rating k.
    densities, _, _ = ax.hist(all_ratings, bins=np.arange(1, 12) - 0.3, width=0.6,
                              density=True, color='tab:blue')

    # Major ticks every 0.05 and minor every 0.025, covering at least [0, 0.15]
    # and extended upward when the tallest bar is higher. A NaN peak (no data)
    # compares False and keeps the 0.15 default.
    peak = float(np.max(densities))
    top = peak if peak > 0.15 else 0.15
    n_major = int(np.ceil(top / 0.05 - 1e-9))
    major_ticks = np.round(0.05 * np.arange(n_major + 1), 10)
    minor_ticks = 0.025 * np.arange(2 * n_major + 1)
    ax.set_yticks(major_ticks)
    ax.set_yticks(minor_ticks, minor=True)
    ax.grid(alpha=0.5, axis='y', which='both')
    ax.set_xlabel("Rating")
    ax.set_ylabel("Density")
    ax.set_xlim((0.5, 10.5))
    ax.set_xticks(np.arange(1, 10.5))
def plot_rating_vs_length(sac_path):
    """Regress individual ratings on prompt length and return their Pearson r.

    One (length, rating) point is produced per individual rating. The
    ``rating`` cell is parsed with ``ast.literal_eval`` and flattened one
    level. The length is the number of elements in the parsed ``tokenized``
    cell. Rows without ratings contribute nothing. A binned ``seaborn.regplot``
    (50 x bins) is drawn into a new current matplotlib figure.

    Args:
        sac_path: Path to a CSV file with ``rating`` and ``tokenized`` columns.

    Returns:
        The Pearson correlation coefficient between prompt length and rating.
    """
    sac_file = pd.read_csv(sac_path)
    lengths = []
    ratings = []
    for rating_cell, tokenized_cell in zip(sac_file['rating'], sac_file['tokenized']):
        row_ratings = list(itertools.chain.from_iterable(literal_eval(rating_cell)))
        if not row_ratings:
            continue
        # Parse the token list once per row rather than once per rating.
        length = len(literal_eval(tokenized_cell))
        lengths.extend([length] * len(row_ratings))
        ratings.extend(row_ratings)

    plt.figure(figsize=(10, 6))
    s = sns.regplot(x=lengths, y=ratings, x_bins=50, color='tab:blue')
    # Replace the tick labels with fixed text equal to the current tick positions.
    s.set_xticklabels(s.get_xticks())
    s.set_yticklabels(s.get_yticks())
    plt.grid(alpha=0.5)
    plt.xlabel('Prompt length')
    plt.ylabel('Rating')
    pearson_corr, _ = stats.pearsonr(lengths, ratings)
    return pearson_corr
def main():
    """Command-line entry point: run the prompt-level analyses for one dataset.

    Prints basic statistics and writes the top-k prompt CSV. It also saves each
    figure drawn by the plotting helpers as ``<dataset>_<name>.png`` under
    ``--save_root``. The helpers only draw into the current matplotlib figure,
    so without an explicit save a non-interactive run would discard them.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_name', required=True, type=str, choices=['midjourney', 'diffusiondb', 'sac', 'laion'])
    parser.add_argument('--data_path', required=True, type=str)
    parser.add_argument('--save_root', required=True, type=str)
    parser.add_argument('--topk_prompts', type=int, default=100)
    args = parser.parse_args()
    # Create the output directory up front so neither the CSV nor the figures fail on a missing path.
    os.makedirs(args.save_root, exist_ok=True)

    def save_current_figure(suffix):
        """Save the current figure as <save_root>/<dataset_name>_<suffix>.png."""
        fig_path = os.path.join(args.save_root, '{}_{}.png'.format(args.dataset_name, suffix))
        plt.savefig(fig_path, bbox_inches='tight')
        print('Figure saved to {}.'.format(fig_path))

    print('Basic statistics of {} (prompt-level):'.format(args.dataset_name))
    prompt_stats(args.dataset_name, args.data_path)
    print('Plot prompt frequency:')
    prompt_freq_path = plot_prompt_freq(args.dataset_name, args.data_path, args.save_root, args.topk_prompts)
    save_current_figure('prompt_freq')
    print('Results of prompts with frequency saved to {}.'.format(prompt_freq_path))
    if args.dataset_name == 'sac':
        print('Plot histogram of ratings (SAC):')
        plot_rating_hist(args.data_path)
        save_current_figure('rating_hist')
        print('Plot of ratings and prompt lengths (SAC):')
        pearson_corr = plot_rating_vs_length(args.data_path)
        save_current_figure('rating_vs_length')
        print('The Pearson correlation coefficient is {}.'.format(pearson_corr))
    elif args.dataset_name in ['midjourney', 'diffusiondb']:
        print('Plot histogram of timestamps (per 24h):')
        plot_timestamp_hist_24h(args.dataset_name, args.data_path)
        save_current_figure('timestamp_hist_24h')


if __name__ == "__main__":
    main()