
Rewrite the model and the inference code
WIP (#3)

* Inference now works with the estimator API (see the sketch after this list).
* Pre-processing and multi-threading are currently a bit hacky.
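
As orientation for the diffs below, here is a minimal sketch of the TF 1.x Estimator inference wiring the commit adopts. The feature key 'ph_sentences' and the prediction key 'output_linear_spec' are taken from the code below; the stand-in model_fn, the shapes, and the model_dir are illustrative assumptions, and predict still needs a trained checkpoint to restore from:

import numpy as np
import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Stand-in network; the real model_fn is Tacotron's (tacotron/model.py).
    inputs = tf.cast(features['ph_sentences'], tf.float32)
    outputs = tf.layers.dense(inputs, units=4)
    return tf.estimator.EstimatorSpec(
        mode, predictions={'output_linear_spec': outputs})

def input_fn():
    # The whole input pipeline must be built inside this callable so the
    # ops land in the graph the estimator creates for prediction.
    dataset = tf.data.Dataset.from_tensor_slices(
        np.array([[1, 2, 3, 0]], dtype=np.int32))  # one padded id sequence
    sentences = dataset.batch(1).make_one_shot_iterator().get_next()
    return {'ph_sentences': sentences}, None

estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir='model')
for result in estimator.predict(input_fn=input_fn,
                                predict_keys=['output_linear_spec']):
    print(result['output_linear_spec'].shape)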
yweweler committed Nov 29, 2018
1 parent 84709c0 commit 3623569
Showing 4 changed files with 102 additions and 80 deletions.
47 changes: 23 additions & 24 deletions tacotron/inference.py
@@ -144,6 +144,8 @@ def main(_):
assert os.path.exists(checkpoint_file), \
'The requested checkpoint file "{}" does not exist.'.format(checkpoint_file)

print('Loading checkpoint from "{}"'.format(checkpoint_file))

# Create a dataset loader.
dataset = dataset_params.dataset_loader(dataset_folder=dataset_params.dataset_folder,
char_dict=dataset_params.vocabulary_dict,
@@ -159,13 +161,15 @@ def main(_):

sentences = py_pre_process_sentences(raw_sentences, dataset)

def sentence_generator(_sentences):
def __build_sentence_generator(*args):
_sentences = args[0]
for s in _sentences:
yield s

input_fn = inference_input_fn(
dataset_loader=dataset,
sentence_generator=sentence_generator(sentences)
sentence_generator=__build_sentence_generator,
sentences=sentences
)

estimator = tf.estimator.Estimator(
@@ -176,34 +180,29 @@ def sentence_generator(_sentences):
)

# Start prediction.
print('calling: estimator.predict')
predict_result = estimator.predict(input_fn=input_fn,
hooks=None,
predict_keys=['output_linear_spec'],
checkpoint_path=checkpoint_file)

print('Prediction result: {}'.format(predict_result))

# # Create batched placeholders for inference.
# placeholders = Tacotron.model_placeholders()
#
# # Create the Tacotron model.
# tacotron_model = Tacotron(inputs=placeholders, mode=Mode.PREDICT)
#
# # generate linear scale magnitude spectrograms.
# specs = inference(tacotron_model, sentences)
#
# wavs = py_post_process_spectrograms(specs)
#
# # Write all generated waveforms to disk.
# for i, (sentence, wav) in enumerate(zip(raw_sentences, wavs)):
# # Append ".wav" to the sentence line number to get the filename.
# file_name = '{}.wav'.format(i + 1)
#
# # Generate the full path under which to save the wav.
# save_path = os.path.join(inference_params.synthesis_dir, file_name)
#
# # Write the wav to disk.
# save_wav(save_path, wav, model_params.sampling_rate, True)
# print('Saved: "{}"'.format(save_path))
# Write all generated waveforms to disk.
for i, (sentence, result) in enumerate(zip(raw_sentences, predict_result)):
spectrogram = result['output_linear_spec']
wavs = py_post_process_spectrograms([spectrogram])
wav = wavs[0]

# Append ".wav" to the sentence line number to get the filename.
file_name = '{}.wav'.format(i + 1)

# Generate the full path under which to save the wav.
save_path = os.path.join(inference_params.synthesis_dir, file_name)

# Write the wav to disk.
save_wav(save_path, wav, model_params.sampling_rate, True)
print('Saved: "{}"'.format(save_path))


if __name__ == '__main__':
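
One subtlety in the diff above: estimator.predict returns a lazy generator, so the debug print right after the call only shows the generator object itself. The checkpoint is restored and the network runs batch by batch only once the save loop consumes the results. A small fragment, reusing the names from the code above, to make the behaviour explicit:

# Nothing has executed yet at this point; predict() only builds a generator.
predict_result = estimator.predict(input_fn=input_fn,
                                   predict_keys=['output_linear_spec'],
                                   checkpoint_path=checkpoint_file)
print(predict_result)  # <generator object Estimator.predict at 0x...>

# The graph runs lazily as the generator is consumed, e.g. by zip() in
# the save loop; each element is a dict of numpy arrays.
first_result = next(predict_result)  # {'output_linear_spec': np.ndarray}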
63 changes: 41 additions & 22 deletions tacotron/input/functions.py
@@ -1,6 +1,8 @@
import tensorflow as tf
from tensorflow.python.data.experimental.ops import grouping

import numpy as np

from tacotron.input.helpers import derive_bucket_boundaries
from tacotron.params.evaluation import evaluation_params
from tacotron.params.model import model_params
@@ -38,38 +40,55 @@ def eval_input_fn(dataset_loader):
model_n_fft=model_params.n_fft)


def inference_input_fn(dataset_loader, sentence_generator):
def inference_input_fn(dataset_loader, sentence_generator, sentences):
print('called: inference_input_fn')
return __build_inference_input_fn(dataset_loader=dataset_loader,
sentence_generator=sentence_generator,
n_threads=inference_params.n_threads)
sentences=sentences,
n_threads=inference_params.n_synthesis_threads)


# TODO: Debug and implement sentence pre-processing.
def __build_inference_input_fn(dataset_loader, sentence_generator, n_threads):
dataset = tf.data.Dataset.from_generator(sentence_generator,
(tf.int32),
(tf.TensorShape([None, 1])))

def __element_pre_process_fn(sentence):
processed_tensors = (
tf.decode_raw(sentence, tf.int32),
)
return processed_tensors
def __build_inference_input_fn(dataset_loader, sentence_generator, sentences, n_threads):
print('called: __build_inference_input_fn')

# Pre-process dataset elements.
dataset = dataset.map(__element_pre_process_fn, num_parallel_calls=n_threads)
def __input_fn():
def __const_generator():
yield np.array([0] * 32, dtype=np.int32)

dataset = tf.data.Dataset.from_generator(__const_generator,
(tf.int32),
(tf.TensorShape([None, ])))

# dataset = tf.data.Dataset.from_generator(sentence_generator,
# (tf.int32),
# (tf.TensorShape([None, 1])),
# args=[sentences])
#
# def __element_pre_process_fn(sentence):
# processed_tensors = (
# tf.decode_raw(sentence, tf.int32),
# )
# return processed_tensors
#
# # Pre-process dataset elements.
# # dataset = dataset.map(__element_pre_process_fn, num_parallel_calls=n_threads)

dataset = dataset.batch(1)

# Create an iterator over the dataset.
iterator = dataset.make_one_shot_iterator()
# Create an iterator over the dataset.
iterator = dataset.make_one_shot_iterator()

# Get features from the iterator.
ph_sentences = iterator.__next__()
# Get features from the iterator.
ph_sentences = iterator.get_next()

features = {
'ph_sentences': ph_sentences
}
features = {
'ph_sentences': ph_sentences
}

return features, None
return features, None

return __input_fn


def __build_input_fn(dataset_loader, max_samples, batch_size, n_epochs, n_threads,
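
The commented-out from_generator path above is presumably the intended, non-hacky design. A sketch of how it could look once debugged, assuming the sentences arrive as variable-length int32 id vectors; closing over the Python list sidesteps the args parameter, whose values must each be convertible to a single dense tensor (which a ragged list of sentences is not):

import numpy as np
import tensorflow as tf

def build_inference_input_fn(sentences):
    def input_fn():
        def generator():
            for sentence in sentences:
                yield sentence

        # Each element is one variable-length vector of character ids.
        dataset = tf.data.Dataset.from_generator(
            generator, tf.int32, tf.TensorShape([None]))
        dataset = dataset.batch(1)

        features = {'ph_sentences': dataset.make_one_shot_iterator().get_next()}
        return features, None

    return input_fn

sentences = [np.array([7, 2, 9], dtype=np.int32),
             np.array([4, 4], dtype=np.int32)]
input_fn = build_inference_input_fn(sentences)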
33 changes: 17 additions & 16 deletions tacotron/input/helpers.py
@@ -34,6 +34,22 @@ def py_pre_process_sentences(_sentences, dataset):
return sentences


def __py_synthesize(linear_mag):
# linear_mag = np.squeeze(linear_mag, -1)
linear_mag = np.power(linear_mag, model_params.magnitude_power)

win_len = ms_to_samples(model_params.win_len, model_params.sampling_rate)
win_hop = ms_to_samples(model_params.win_hop, model_params.sampling_rate)
n_fft = model_params.n_fft

print('Spectrogram inversion ...')
return spectrogram_to_wav(linear_mag,
win_len,
win_hop,
n_fft,
model_params.reconstruction_iterations)


def py_post_process_spectrograms(_spectrograms):
# Apply Griffin-Lim to all spectrograms to get the waveforms.
normalized = list()
@@ -48,24 +64,9 @@ def py_post_process_spectrograms(_spectrograms):

specs = normalized

win_len = ms_to_samples(model_params.win_len, model_params.sampling_rate)
win_hop = ms_to_samples(model_params.win_hop, model_params.sampling_rate)
n_fft = model_params.n_fft

def synthesize(linear_mag):
linear_mag = np.squeeze(linear_mag, -1)
linear_mag = np.power(linear_mag, model_params.magnitude_power)

print('Spectrogram inversion ...')
return spectrogram_to_wav(linear_mag,
win_len,
win_hop,
n_fft,
model_params.reconstruction_iterations)

# Synthesize waveforms from the spectrograms.
pool = ThreadPool(inference_params.n_synthesis_threads)
wavs = pool.map(synthesize, specs)
wavs = pool.map(__py_synthesize, specs)
pool.close()
pool.join()

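
Moving the synthesis helper to module level (__py_synthesize) also makes it picklable, which keeps the door open for a process pool later; with the current ThreadPool it mainly tidies the code. A self-contained sketch of the fan-out pattern, with a stub standing in for the Griffin-Lim inversion:

from multiprocessing.pool import ThreadPool

import numpy as np

def synthesize(linear_mag):
    # Stub for __py_synthesize; the real version inverts the linear
    # magnitude spectrogram with Griffin-Lim via spectrogram_to_wav.
    return np.zeros(100, dtype=np.float32)

specs = [np.random.rand(80, 1025).astype(np.float32) for _ in range(4)]

# Threads are sufficient here since the heavy FFT work in numpy/librosa
# largely releases the GIL, and a thread pool avoids pickling the specs.
pool = ThreadPool(4)
wavs = pool.map(synthesize, specs)
pool.close()
pool.join()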
39 changes: 21 additions & 18 deletions tacotron/model.py
@@ -363,10 +363,11 @@ def model_fn(self, features, labels, mode, params):
"""
# Get the placeholders for the input data.
self.inp_sentences = features['ph_sentences']
self.seq_lengths = features['ph_sentence_lengths']
self.inp_mel_spec = features['ph_mel_specs']
self.inp_linear_spec = features['ph_lin_specs']
self.inp_time_steps = features['ph_time_frames']
if mode != tf.estimator.ModeKeys.PREDICT:
self.seq_lengths = features['ph_sentence_lengths']
self.inp_mel_spec = features['ph_mel_specs']
self.inp_linear_spec = features['ph_lin_specs']
self.inp_time_steps = features['ph_time_frames']

# inp_sentences.shape = (B, T_sent, ?)
batch_size = tf.shape(self.inp_sentences)[0]
@@ -402,12 +403,13 @@ def model_fn(self, features, labels, mode, params):
# shape => (B, T_spec, (1 + n_fft // 2))
self.output_linear_spec = outputs

inp_mel_spec = self.inp_mel_spec
inp_linear_spec = self.inp_linear_spec
if mode != tf.estimator.ModeKeys.PREDICT:
inp_mel_spec = self.inp_mel_spec
inp_linear_spec = self.inp_linear_spec

inp_mel_spec = tf.reshape(inp_mel_spec, [batch_size, -1, self.hparams.n_mels])
inp_linear_spec = tf.reshape(inp_linear_spec,
[batch_size, -1, (1 + self.hparams.n_fft // 2)])
inp_mel_spec = tf.reshape(inp_mel_spec, [batch_size, -1, self.hparams.n_mels])
inp_linear_spec = tf.reshape(inp_linear_spec,
[batch_size, -1, (1 + self.hparams.n_fft // 2)])

output_mel_spec = self.output_mel_spec
output_linear_spec = self.output_linear_spec
@@ -432,21 +434,21 @@ def model_fn(self, features, labels, mode, params):
tf.summary.image('linear_spec_gt_loss', linear_spec_image, max_outputs=1)
# ======================================================================================

# Calculate decoder Mel. spectrogram loss.
self.loss_op_decoder = tf.reduce_mean(
tf.abs(inp_mel_spec - output_mel_spec))
if mode != tf.estimator.ModeKeys.PREDICT:
# Calculate decoder Mel. spectrogram loss.
self.loss_op_decoder = tf.reduce_mean(
tf.abs(inp_mel_spec - output_mel_spec))

# Calculate post-processing linear spectrogram loss.
self.loss_op_post_processing = tf.reduce_mean(
tf.abs(inp_linear_spec - output_linear_spec))
# Calculate post-processing linear spectrogram loss.
self.loss_op_post_processing = tf.reduce_mean(
tf.abs(inp_linear_spec - output_linear_spec))

# Combine the decoder and the post-processing losses.
self.loss_op = self.loss_op_decoder + self.loss_op_post_processing
# Combine the decoder and the post-processing losses.
self.loss_op = self.loss_op_decoder + self.loss_op_post_processing

summary_op = self.summary(mode)

if self.is_training(mode):

# NOTE: The global step has to be created before the optimizer is created.
global_step = tf.train.get_global_step()

@@ -522,6 +524,7 @@ def model_fn(self, features, labels, mode, params):
# evaluation_hooks=[summary_hook]
)
elif mode == tf.estimator.ModeKeys.PREDICT:
print('Model was built in inference mode.')
# Dictionary that is returned on `estimator.predict`.
predictions = {
"output_mel_spec": self.output_mel_spec,
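
The pattern in the model.py diff, condensed: label-dependent ops (losses, summaries on ground truth) are only built when the mode provides labels, and PREDICT mode returns an EstimatorSpec that carries nothing but the predictions dict. A minimal self-contained sketch of that guard structure, with a stand-in network:

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # Stand-in network over the input ids.
    inputs = tf.cast(features['ph_sentences'], tf.float32)
    outputs = tf.layers.dense(inputs, units=8)

    if mode == tf.estimator.ModeKeys.PREDICT:
        # No labels exist at inference time, so no loss ops are built.
        predictions = {'output_linear_spec': outputs}
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    # TRAIN / EVAL: labels are available, so the loss can be constructed.
    loss = tf.reduce_mean(tf.abs(tf.cast(labels, tf.float32) - outputs))
    optimizer = tf.train.AdamOptimizer()
    train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)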
