-
Notifications
You must be signed in to change notification settings - Fork 2
/
convert-to-coreml
executable file
·59 lines (44 loc) · 2.21 KB
/
convert-to-coreml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
#!/usr/bin/env python3
# A script for converting the Spleeter separator to Core ML.
# Useful resources:
# - https://github.com/deezer/spleeter/issues/210
# - https://github.com/deezer/spleeter/issues/155
# - https://twitter.com/ExtractorVocal/status/1342643493227773952
import argparse
import coremltools as ct
import torch
from pathlib import Path
from spleeter_pytorch.estimator import Estimator
# Directory containing this script; the default checkpoint and output
# locations used by the command-line options are resolved relative to it.
ROOT = Path(__file__).resolve().parent
def main():
    """Convert the Spleeter separator network to a Core ML ``.mlpackage``.

    Steps:
      1. Parse the command-line options.
      2. Load the estimator from the checkpoint and switch it to eval mode.
      3. Compute the STFT magnitude of an all-zero, two-channel waveform of
         ``--length`` seconds; it serves as the example input for tracing.
      4. Trace ``estimator.separator`` with TorchScript.
      5. Convert the traced module with coremltools (``mlprogram`` format)
         using the fixed shape of the example input.
      6. Save the result as
         ``<output>/Spleeter-<num-instruments>stems.mlpackage``.

    The STFT itself is not part of the converted model, so callers of the
    Core ML model have to perform that preprocessing themselves.

    Exits via ``parser.error`` (status 2) when ``--length`` is too short to
    yield at least one audio sample.
    """
    parser = argparse.ArgumentParser(description='Converts Spleeter (minus the STFT preprocessing) to Core ML')
    parser.add_argument('-n', '--num-instruments', type=int, default=2, help='The number of stems.')
    parser.add_argument('-m', '--model', type=Path, default=ROOT / 'checkpoints' / '2stems' / 'model', help='The path to the model to use.')
    parser.add_argument('-o', '--output', type=Path, default=ROOT / 'output' / 'coreml', help='The output directory to place the model in')
    parser.add_argument('-l', '--length', type=float, default=5.0, help='The input length in seconds for the converted Core ML model (which will only take fixed-length inputs). Default: 5 seconds')

    args = parser.parse_args()

    samplerate = 44100
    num_samples = int(args.length * samplerate)
    # Reject zero, negative, or sub-sample lengths up front instead of
    # failing later inside the STFT or the tracer with an obscure shape error.
    if num_samples < 1:
        parser.error('--length must be long enough to contain at least one audio sample')

    estimator = Estimator(num_instruments=args.num_instruments, checkpoint_path=args.model)
    estimator.eval()

    # Create sample 'audio' for tracing
    wav = torch.zeros(2, num_samples)

    # This is inference only, so skip autograd bookkeeping while computing
    # the example input and tracing.
    with torch.no_grad():
        # Reproduce the STFT step (which we cannot convert to Core ML, unfortunately)
        _, stft_mag = estimator.compute_stft(wav)

        print('==> Tracing model')
        # torch.jit.trace already executes the module on the example input,
        # so no separate forward pass is needed afterwards.
        traced_model = torch.jit.trace(estimator.separator, stft_mag)

    print('==> Converting to Core ML')
    mlmodel = ct.convert(
        traced_model,
        convert_to='mlprogram',
        # TODO: Investigate whether we'd want to make the input shape flexible
        # See https://coremltools.readme.io/docs/flexible-inputs
        inputs=[ct.TensorType(shape=stft_mag.shape)],
    )

    output_dir: Path = args.output
    output_dir.mkdir(parents=True, exist_ok=True)
    output = output_dir / f'Spleeter-{args.num_instruments}stems.mlpackage'

    print(f'==> Writing {output}')
    # MLModel.save is documented to take a string path.
    mlmodel.save(str(output))
# Run the conversion only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()