Export Mimi model to ExecuTorch
Martin Yuan committed Feb 27, 2025
1 parent afcec1d commit 6666d8a
Showing 2 changed files with 160 additions and 0 deletions.
15 changes: 15 additions & 0 deletions examples/models/moshi/mimi/install_requirements.sh
@@ -0,0 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -x

pip install -U moshi
pip install bitsandbytes
# Run the llama example's install_requirements.sh for torchao deps
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

bash "$SCRIPT_DIR"/../llama/install_requirements.sh
145 changes: 145 additions & 0 deletions examples/models/moshi/mimi/mimi_test.py
@@ -0,0 +1,145 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random
import time

import numpy as np
import sphn
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.export import export, export_for_training, ExportedProgram
from torch.profiler import profile, ProfilerActivity

from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
from executorch.examples.models.llama.llama_transformer import Transformer
from executorch.examples.models.llama.model_args import ModelArgs
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

parser = argparse.ArgumentParser()
parser.add_argument("--mimi-weight", type=str)
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
parser.add_argument(
    "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu"
)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()
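# Example invocation (illustrative, not from the commit):
#   python examples/models/moshi/mimi/mimi_test.py --profile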


def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(42424242)


print("loading mimi")
if args.mimi_weight is None:
    args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
mimi = loaders.get_mimi(args.mimi_weight, args.device)
print("mimi loaded")
# emb = torch.load('emb.pt')

def mimi_test(mimi, max_duration_sec=10.0):
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    sample_rate = mimi.sample_rate
    # Uncomment below to get real audio
    # # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
    # sample_pcm, sample_sr = sphn.read("/Users/myuan/src/moshi0/src/moshi/data/bria-24khz.mp3")
    # print("loaded pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sphn.resample(
    #     sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
    # )
    # sample_pcm = torch.tensor(sample_pcm, device=args.device)
    # max_duration_len = int(sample_rate * max_duration_sec)
    # if sample_pcm.shape[-1] > max_duration_len:
    #     sample_pcm = sample_pcm[..., :max_duration_len]
    # print("resampled pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sample_pcm[None].to(device=args.device)
    #
    # Dummy input: 240000 samples is 10 s of audio at Mimi's 24 kHz sample rate.
    sample_pcm = torch.ones(1, 1, 240000)

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []

    def run_loop():
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = mimi.encode(chunk)
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    if args.profile:
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            run_loop()
        prof.export_chrome_trace("trace.json")
    else:
        run_loop()
    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
    print("streaming decoding...")
    all_pcms = []
    # with mimi.streaming(1):
    #     for i in range(all_codes_th.shape[-1]):
    #         codes = all_codes_th[..., i : i + 1]
    #         pcm = mimi.decode(codes)
    #         print(i, pcm.shape, end="\r")
    #         all_pcms.append(pcm)
    #     all_pcms = torch.cat(all_pcms, dim=-1)
    #     print("pcm", all_pcms.shape, all_pcms.dtype)
    #     sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
    # Eager-mode reference decode of the full code sequence.
    pcm_ref = mimi.decode(all_codes_th)

    class MimiDecode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    mimi_decode = MimiDecode(mimi)

    # Export the decode wrapper. to_edge_transform_and_lower expects the
    # ExportedProgram itself, not the GraphModule returned by .module().
    ep: ExportedProgram = torch.export.export(mimi_decode, (all_codes_th,), strict=False)
    edge_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
    )
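    # A possible next step (illustrative sketch, not part of this commit):
    # serialize the lowered decode program to a .pte file for the ExecuTorch
    # runtime, and sanity-check the exported graph against pcm_ref.
    # exec_prog = edge_prog.to_executorch()
    # with open("mimi_decode.pte", "wb") as f:
    #     f.write(exec_prog.buffer)
    # torch.testing.assert_close(ep.module()(all_codes_th), pcm_ref)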
    class MimiEncode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode = MimiEncode(mimi)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    out = mimi_encode(chunk)
    exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module()
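    # Illustrative follow-up (assumption, mirroring the decode path above):
    # lower the encode graph through XNNPACK and compare against eager output.
    # encode_edge = to_edge_transform_and_lower(
    #     torch.export.export(mimi_encode, (chunk,), strict=False),
    #     partitioner=[XnnpackPartitioner()],
    # )
    # torch.testing.assert_close(exported_encode(chunk), out)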

with torch.no_grad():
    mimi_test(mimi)
