Martin Yuan committed on Feb 27, 2025
commit 6666d8a (1 parent: afcec1d)
Showing 2 changed files with 160 additions and 0 deletions.
@@ -0,0 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -x

pip install -U moshi
pip install bitsandbytes

# Run the llama example's install_requirements.sh for the torchao dependencies.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

bash "$SCRIPT_DIR"/../llama/install_requirements.sh
@@ -0,0 +1,145 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random
import time

import numpy as np
import sphn  # only needed by the commented-out real-audio path below
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from moshi.models import loaders
from torch.export import ExportedProgram
from torch.profiler import profile, ProfilerActivity

from executorch.backends.xnnpack.partition.xnnpack_partitioner import (
    XnnpackPartitioner,
)
from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)

parser = argparse.ArgumentParser()
parser.add_argument("--mimi-weight", type=str)
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
parser.add_argument(
    "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu"
)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()

def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(42424242)

print("loading mimi") | ||
if args.mimi_weight is None: | ||
args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME) | ||
mimi = loaders.get_mimi(args.mimi_weight, args.device) | ||
print("mimi loaded") | ||
# emb = torch.load('emb.pt') | ||
|
||
def mimi_test(mimi, max_duration_sec=10.0):
    # One chunk is one codec frame of audio (sample_rate / frame_rate samples).
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    sample_rate = mimi.sample_rate
    # Uncomment below to use real audio instead of the synthetic input:
    # # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
    # sample_pcm, sample_sr = sphn.read("/Users/myuan/src/moshi0/src/moshi/data/bria-24khz.mp3")
    # print("loaded pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sphn.resample(
    #     sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
    # )
    # sample_pcm = torch.tensor(sample_pcm, device=args.device)
    # max_duration_len = int(sample_rate * max_duration_sec)
    # if sample_pcm.shape[-1] > max_duration_len:
    #     sample_pcm = sample_pcm[..., :max_duration_len]
    # print("resampled pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sample_pcm[None].to(device=args.device)
    #
    # Synthetic input: 10 s of constant audio at 24 kHz, shape (1, 1, 240000).
    sample_pcm = torch.ones(1, 1, 240000, device=args.device)

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []

    def run_loop():
        # Encode the waveform one frame-sized chunk at a time.
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = mimi.encode(chunk)
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    if args.profile:
        # Export a Chrome trace (viewable at chrome://tracing).
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            run_loop()
        prof.export_chrome_trace("trace.json")
    else:
        run_loop()
    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
    print("streaming decoding...")
    all_pcms = []
    # with mimi.streaming(1):
    #     for i in range(all_codes_th.shape[-1]):
    #         codes = all_codes_th[..., i : i + 1]
    #         pcm = mimi.decode(codes)
    #         print(i, pcm.shape, end="\r")
    #         all_pcms.append(pcm)
    # all_pcms = torch.cat(all_pcms, dim=-1)
    # print("pcm", all_pcms.shape, all_pcms.dtype)
    # sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
    # Non-streaming decode of all codes, kept as an eager-mode reference.
    pcm_ref = mimi.decode(all_codes_th)

    class MimiDecode(nn.Module):
        """Wraps the decoder so torch.export sees a plain forward()."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    mimi_decode = MimiDecode(mimi)

    # Export the decoder and lower it to ExecuTorch with the XNNPACK backend.
    # Note: to_edge_transform_and_lower expects the ExportedProgram itself,
    # not the GraphModule returned by .module().
    ep: ExportedProgram = torch.export.export(mimi_decode, (all_codes_th,), strict=False)
    edge_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
    )
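    # Not in this commit: a sketch of serializing the lowered decoder to a
    # .pte file via the standard ExecuTorch flow (the output name
    # "mimi_decode.pte" is an illustrative assumption).
    # exec_prog = edge_prog.to_executorch(ExecutorchBackendConfig())
    # with open("mimi_decode.pte", "wb") as f:
    #     f.write(exec_prog.buffer)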
    class MimiEncode(nn.Module):
        """Wraps the encoder so torch.export sees a plain forward()."""

        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode = MimiEncode(mimi)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    out = mimi_encode(chunk)  # eager reference output for one chunk
    exported_encode: ExportedProgram = torch.export.export(
        mimi_encode, (chunk,), strict=False
    )
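    # Not in this commit: a sketch of checking the exported encoder against
    # the eager output and lowering it like the decoder (the name
    # "encode_edge" and the tolerance are illustrative assumptions).
    # assert torch.allclose(exported_encode.module()(chunk), out, atol=1e-4)
    # encode_edge = to_edge_transform_and_lower(
    #     exported_encode,
    #     partitioner=[XnnpackPartitioner()],
    # )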

with torch.no_grad():
    mimi_test(mimi)
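
To run the end-to-end test (the script name below is hypothetical; the flags come from the argparse setup above):

python mimi_streaming_test.py --device cpu --profile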