Skip to content

Commit

Permalink
Merge pull request #399 from lukewys:update-urmp-dataloader
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 411589575
  • Loading branch information
Magenta Team committed Nov 22, 2021
2 parents b942d2a + bb6aed5 commit 505a6e5
Showing 1 changed file with 19 additions and 4 deletions.
23 changes: 19 additions & 4 deletions ddsp/training/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,11 @@ def features_dict(self):
class Urmp(TFRecordProvider):
"""Urmp training set."""

def __init__(self, base_dir, instrument_key='tpt', split='train'):
def __init__(self,
base_dir,
instrument_key='tpt',
split='train',
suffix=None):
"""URMP dataset for either a specific instrument or all instruments.
Args:
Expand All @@ -420,19 +424,30 @@ def __init__(self, base_dir, instrument_key='tpt', split='train'):
['all', 'bn', 'cl', 'db', 'fl', 'hn', 'ob', 'sax', 'tba', 'tbn',
'tpt', 'va', 'vc', 'vn'].
split: Choices include ['train', 'test'].
suffix: Choices include [None, 'batched', 'unbatched'], but any string
may be given as a suffix to add to the file pattern.
When suffix is not None, "_{suffix}" is inserted after the split name in
the file pattern.
This option is used in gs://magentadata/datasets/urmp/urmp_20210324.
With the "batched" suffix, the dataloader will load tfrecords
containing audio samples segmented into 4-second clips. With the
"unbatched" suffix, the dataloader will load tfrecords containing
unsegmented samples, which can be used for learning note sequences in the
URMP dataset.
"""
self.instrument_key = instrument_key
self.split = split
self.base_dir = base_dir
self.suffix = '' if suffix is None else '_' + suffix
super().__init__()

@property
def default_file_pattern(self):
if self.instrument_key == 'all':
file_pattern = 'all_instruments_{}.tfrecord*'.format(self.split)
file_pattern = 'all_instruments_{}{}.tfrecord*'.format(
self.split, self.suffix)
else:
file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}.tfrecord*'.format(
self.instrument_key, self.split)
file_pattern = 'urmp_{}_solo_ddsp_conditioning_{}{}.tfrecord*'.format(
self.instrument_key, self.split, self.suffix)

return os.path.join(self.base_dir, file_pattern)

Expand Down

0 comments on commit 505a6e5

Please sign in to comment.