'''
Filename: extract_train_data.py
Authors: Luke Rowe, Jing Zhu, Quinton Yong
Date: April 19, 2019
Description: This file extracts and preprocesses the training data used to train the
MIML network. The extracted training data is written to an h5 file.
'''
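# NOTE (assumed prerequisites, inferred from the code below, not stated in the
# original file): the ffmpeg and youtube-dl binaries must be on the PATH, the
# AudioSet csv file "unbalanced_train_segments.csv" must sit in the working
# directory, and preprocessing.py must provide get_frame_labels, get_frames,
# and extract_bases.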
import csv
import os
import cv2
import time
import h5py
import librosa as li
import numpy as np
import torch
from tqdm import tqdm
from preprocessing import get_frame_labels, get_frames, extract_bases
def find_training_data():
    print("Extracting training labels...")
    training_data_filename = "unbalanced_train_segments.csv"  # use the unbalanced AudioSet segments for training

    # AudioSet label IDs for the four instrument classes
    instrument_dict = {'drum': '/m/026t6', 'guitar': '/m/042v_gx', 'piano': '/m/05r5c', 'violin': '/m/07y_7'}

    drum_data = []
    guitar_data = []
    piano_data = []
    violin_data = []

    # load the video info from the csv file
    with open(training_data_filename, mode='r') as file:
        csv_reader = csv.reader(file)

        # skip the 3 header lines of the csv file
        next(csv_reader)
        next(csv_reader)
        next(csv_reader)

        for row in csv_reader:
            # process the encoded labels
            curr_label_id_list = row[3:len(row)]

            # remove the " characters from the label ids
            curr_label_id_list[0] = curr_label_id_list[0].replace("\"", "").strip()
            curr_label_id_list[-1] = curr_label_id_list[-1].replace("\"", "")

            # collect up to 4000 videos per class for training
            if instrument_dict['drum'] in curr_label_id_list and len(drum_data) < 4000:
                drum_data.append([row[0], row[1].strip()])
            elif instrument_dict['guitar'] in curr_label_id_list and len(guitar_data) < 4000:
                guitar_data.append([row[0], row[1].strip()])
            elif instrument_dict['piano'] in curr_label_id_list and len(piano_data) < 4000:
                piano_data.append([row[0], row[1].strip()])
            elif instrument_dict['violin'] in curr_label_id_list and len(violin_data) < 4000:
                violin_data.append([row[0], row[1].strip()])

    print("Number of drum samples:", len(drum_data))
    print("Number of guitar samples:", len(guitar_data))
    print("Number of piano samples:", len(piano_data))
    print("Number of violin samples:", len(violin_data))

    return drum_data + guitar_data + piano_data + violin_data
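
# For reference, each data row of the AudioSet csv is expected to look like
# (an illustrative example, not taken from this repo):
#   --aE2O5G5WE, 0.000, 10.000, "/m/03fwl,/m/04rlf"
# i.e. YouTube ID, start time, end time, then the quoted list of label ids
# that find_training_data() above parses.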

def get_audio_image(tr_data):
    num_skipped_videos = 0

    # accumulators for the audio bases (W matrices) and frame labels;
    # initialized to None so that the first successfully processed sample
    # creates them, even if earlier samples were skipped
    W_all = None
    labels_all = None

    for count in tqdm(range(len(tr_data))):
        sample = tr_data[count]
        url = 'https://www.youtube.com/watch?v=' + sample[0]
        video_start_time = sample[1]

        # download a 10-second clip from YouTube and extract its audio track
        os.system("ffmpeg -ss " + str(video_start_time) + " -i $(youtube-dl -i -f 37/22/18 -g '"
                  + url + "') -t 10 -c copy video.mp4 >/dev/null 2>&1")
        os.system("ffmpeg -i video.mp4 audio.wav >/dev/null 2>&1")

        # skip the sample if the video failed to download
        if not os.path.exists("./video.mp4"):
            print("Error in downloading youtube video")
            num_skipped_videos += 1
            continue

        # obtain a cv2.VideoCapture object from the downloaded video
        cap = cv2.VideoCapture("video.mp4")

        # load the audio from file
        ts, sr = li.core.load("./audio.wav", sr=48000)

        # skip the sample if the audio is shorter than 10 seconds
        if len(ts) < 10 * sr:
            os.remove("./audio.wav")
            os.remove("./video.mp4")
            print("\n\n\nSample {} is too short to be processed.".format(sample[0]))
            print("Namely, the sample is {} seconds long.\n\n\n".format(len(ts) / sr))
            num_skipped_videos += 1
            continue

        ts = ts[0:10 * sr]  # truncate the audio to exactly 10 seconds

        all_image_tensors, skip = get_frames(cap)  # get all the transformed frames

        # skip the current video if an error occurred during frame extraction
        if skip:
            num_skipped_videos += 1
            print("\n\n\nUnable to extract all frames from sample {}\n\n\n".format(sample[0]))
            if os.path.exists('./audio.wav'):
                os.remove('./audio.wav')
            if os.path.exists('./video.mp4'):
                os.remove('./video.mp4')
            for k in range(skip):
                if os.path.exists('frame{}.jpg'.format(k)):
                    os.remove('frame{}.jpg'.format(k))
            continue

        max_pool_labels = get_frame_labels(all_image_tensors)  # get predicted labels for the captured frames

        # append the basis vectors and object labels for the current audio sample
        if W_all is None:
            W_all = np.expand_dims(extract_bases(ts), 0)  # run NMF to extract the audio bases
            labels_all = max_pool_labels.detach().unsqueeze(0)  # max-pooled ResNet labels
        else:
            W = extract_bases(ts)  # run NMF to extract the audio bases
            W_all = np.concatenate((W_all, np.expand_dims(W, 0)))  # append the audio bases
            labels_all = torch.cat((labels_all, max_pool_labels.detach().unsqueeze(0)), 0)  # max-pooled ResNet labels

        # remove the captured frames and the downloaded video and audio
        for i in range(10):
            os.remove('./frame{}.jpg'.format(i))
        os.remove('./video.mp4')
        os.remove('./audio.wav')

        # checkpoint the data to the h5 file every 500 samples in case the connection is lost
        if count % 500 == 0:
            with h5py.File('./data.h5', 'w') as hdf5:
                hdf5.create_dataset('bases', data=W_all)
                hdf5.create_dataset('labels', data=labels_all)

    # write all the audio frequency bases and labels to the h5 file, ready for training
    with h5py.File('./data.h5', 'w') as hdf5:
        hdf5.create_dataset('bases', data=W_all)
        hdf5.create_dataset('labels', data=labels_all)

    print("{} samples were skipped.".format(num_skipped_videos))

def main():
    # remove wav and mp4 files left over from a previous run
    if os.path.exists('./audio.wav'):
        print("Removing ./audio.wav file from previous run...")
        os.remove('./audio.wav')
    if os.path.exists('./video.mp4'):
        print("Removing ./video.mp4 file from previous run...")
        os.remove('./video.mp4')

    t_start = time.perf_counter()

    tr_data = find_training_data()
    print("Number of samples:", len(tr_data))

    get_audio_image(tr_data)

    print('total time:', time.perf_counter() - t_start)


if __name__ == '__main__':
    main()
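
# Example usage (assuming the csv file and the preprocessing helpers are in place):
#   $ python extract_train_data.py
# The script writes the extracted bases and labels to ./data.h5.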