diff --git a/CMakeLists.txt b/CMakeLists.txt index 983dd7041..15c4f3cdf 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,6 +22,11 @@ option(BUILD_TESTS "Compile the tests" OFF) option(BUILD_SHARED_LIBS "Build shared libraries" ON) option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF) option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF) +option(MOONSHINE "Compile with moonshine specializations" OFF) + +if (MOONSHINE) + add_definitions(-DMOONSHINE) +endif() if(ENABLE_PROFILING) message(STATUS "Enable profiling support") @@ -129,6 +134,7 @@ set(SOURCES src/layers/wav2vec2.cc src/layers/wav2vec2bert.cc src/layers/whisper.cc + src/layers/moonshine.cc src/logging.cc src/models/language_model.cc src/models/model.cc @@ -139,6 +145,7 @@ set(SOURCES src/models/wav2vec2.cc src/models/wav2vec2bert.cc src/models/whisper.cc + src/models/moonshine.cc src/ops/activation.cc src/ops/add.cc src/ops/alibi_add.cc diff --git a/docs/conversion.md b/docs/conversion.md index 76ca28226..2990dace0 100644 --- a/docs/conversion.md +++ b/docs/conversion.md @@ -8,6 +8,7 @@ The Python module includes a [conversion API](python/ctranslate2.converters.rst) * [Fairseq](guides/fairseq.md) * [Marian](guides/marian.md) +* [Moonshine](guides/moonshine.md) * [OpenNMT-py](guides/opennmt_py.md) * [OpenNMT-tf](guides/opennmt_tf.md) * [OPUS-MT](guides/opus_mt.md) diff --git a/docs/guides/moonshine.md b/docs/guides/moonshine.md new file mode 100644 index 000000000..9a6ba6e44 --- /dev/null +++ b/docs/guides/moonshine.md @@ -0,0 +1,10 @@ +# Moonshine + +CTranslate2 supports [Moonshine](https://github.com/usefulsensors/moonshine) transcription models. The conversion requires the paths to the model.safetensors and tokenizer.json files. + +Both files can be downloaded from [Moonshine Tiny](https://huggingface.co/UsefulSensors/moonshine-tiny/tree/main) or [Moonshine Base](https://huggingface.co/UsefulSensors/moonshine-base/tree/main).
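Once the model has been converted with the command shown below, it can be loaded through the new `ctranslate2.models.Moonshine` class added by this patch. The following is a minimal transcription sketch, not taken from this diff: the use of `librosa` and the `tokenizers` package, the 16 kHz sample rate, and the `"<s>"` start-of-text prompt are assumptions on my part.

```python
import ctranslate2
import librosa
from tokenizers import Tokenizer

# Assumptions: Moonshine expects 16 kHz mono audio and "<s>" as the start-of-text prompt.
audio, _ = librosa.load("speech.wav", sr=16000, mono=True)     # float32 array of shape [audio_len]
features = ctranslate2.StorageView.from_array(audio[None, :])  # generate() takes [batch_size, audio_len]

model = ctranslate2.models.Moonshine("ct2_model", device="cpu")
tokenizer = Tokenizer.from_file("tokenizer.json")

results = model.generate(features, [["<s>"]])
print(tokenizer.decode(results[0].sequences_ids[0]))
```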
+ +```bash +ct2-moonshine-converter --model_path model.safetensors --vocab_path tokenizer.json --moonshine_variant tiny \ + --output_dir ct2_model +``` diff --git a/docs/requirements.txt b/docs/requirements.txt index be87cb592..3f80ced29 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,3 +1,4 @@ myst-parser==0.17.* sphinx-rtd-theme==1.0.* sphinx==4.5.* +safetensors[torch] diff --git a/include/ctranslate2/layers/moonshine.h b/include/ctranslate2/layers/moonshine.h new file mode 100644 index 000000000..bf58440b1 --- /dev/null +++ b/include/ctranslate2/layers/moonshine.h @@ -0,0 +1,75 @@ +#include "ctranslate2/layers/transformer.h" + +namespace ctranslate2 { + namespace layers { + + class MoonshinePreprocessor : public Layer { + public: + MoonshinePreprocessor(const models::Model& model, const std::string& scope); + + void operator()(const StorageView& features, StorageView& output); + + DataType output_type() const override { + return _conv3.output_type(); + } + + dim_t output_size() const override { + return _conv3.output_size(); + } + + dim_t input_size() const { + return _conv1.input_size(); + } + private: + const Conv1D _conv1; + const ops::Tanh _tanh; + const LayerNorm _norm; + const Conv1D _conv2; + const ops::GELU _gelu1; + const Conv1D _conv3; + const ops::GELU _gelu2; + const ops::Transpose _transpose; + }; + + + class MoonshineEncoder : public Layer { + public: + MoonshineEncoder(const models::Model& model, const std::string& scope); + + void operator()(const StorageView& features, StorageView& output); + + DataType output_type() const override { + return _output_norm.output_type(); + } + + dim_t output_size() const override { + return _output_norm.output_size(); + } + + bool is_encoded(const StorageView& features) const { + // Input features shape: [batch_size, input_size, input_time] + // Encoder output shape: [batch_size, input_time // 2, output_size] + // + // input_time is variable so we check that dimension 1 is different than its original value. + + return (features.rank() == 3 + && features.dim(2) == output_size() + && features.dim(1) != 1); + } + + private: + const dim_t _num_heads; + const std::vector> _layers; + const LayerNorm _output_norm; + }; + + class MoonshineDecoder : public TransformerDecoder { + public: + using TransformerDecoder::TransformerDecoder; + + bool return_normalized_attention() const override { + return false; + } + }; + } +} diff --git a/include/ctranslate2/models/moonshine.h b/include/ctranslate2/models/moonshine.h new file mode 100644 index 000000000..baeb10bbc --- /dev/null +++ b/include/ctranslate2/models/moonshine.h @@ -0,0 +1,134 @@ +#pragma once + +#include "ctranslate2/generation.h" +#include "ctranslate2/layers/moonshine.h" +#include "ctranslate2/models/model.h" +#include "ctranslate2/replica_pool.h" + +namespace ctranslate2 { + namespace models { + + struct MoonshineOptions { + // Beam size to use for beam search (set 1 to run greedy search). + size_t beam_size = 5; + + // Beam search patience factor, as described in https://arxiv.org/abs/2204.05424. + // The decoding will continue until beam_size*patience hypotheses are finished. + float patience = 1; + + // Exponential penalty applied to the length during beam search. + float length_penalty = 1; + + // Penalty applied to the score of previously generated tokens, as described in + // https://arxiv.org/abs/1909.05858 (set > 1 to penalize). + float repetition_penalty = 1; + + // Prevent repetitions of ngrams with this size (set 0 to disable). 
+ size_t no_repeat_ngram_size = 0; + + // Maximum generation length. + size_t max_length = 448; + + // Randomly sample from the top K candidates (set 0 to sample from the full distribution). + size_t sampling_topk = 1; + + // High temperatures increase randomness. + float sampling_temperature = 1; + + // Number of hypotheses to include in the result. + size_t num_hypotheses = 1; + + // Include scores in the result. + bool return_scores = false; + + // Suppress blank outputs at the beginning of the sampling. + bool suppress_blank = true; + + // List of token IDs to suppress. + // -1 will suppress a default set of symbols as defined in the model config.json file. + std::vector suppress_tokens = {-1}; + }; + + struct MoonshineGenerationResult { + std::vector> sequences; + std::vector> sequences_ids; + std::vector scores; + + size_t num_sequences() const { + return sequences.size(); + } + + bool has_scores() const { + return !scores.empty(); + } + }; + + class MoonshineModel : public Model { + public: + const Vocabulary& get_vocabulary() const; + + size_t current_spec_revision() const override; + bool is_quantizable(const std::string& variable_name) const override; + bool is_linear_weight(const std::string& variable_name) const override; + std::unique_ptr clone() const override; + + bool use_global_int16_scale() const override { + return false; + } + + protected: + void initialize(ModelReader& model_reader) override; + + private: + std::shared_ptr _vocabulary; + }; + + class MoonshineReplica : public ModelReplica { + public: + static std::unique_ptr create_from_model(const Model& model); + + MoonshineReplica(const std::shared_ptr& model); + + StorageView encode(StorageView features, const bool to_cpu); + + std::vector + generate(StorageView features, + const std::vector>& prompts, + const MoonshineOptions& options); + + std::vector + generate(StorageView features, + const std::vector>& prompts, + const MoonshineOptions& options); + + private: + const std::shared_ptr _model; + const std::unique_ptr _preprocessor; + const std::unique_ptr _encoder; + const std::unique_ptr _decoder; + + size_t _sot_id; + size_t _eot_id; + + StorageView maybe_encode(StorageView features); + }; + + class Moonshine : public ReplicaPool { + public: + using ReplicaPool::ReplicaPool; + + std::future encode(const StorageView& features, const bool to_cpu); + + std::vector> + generate(const StorageView& features, + std::vector> prompts, + MoonshineOptions options = {}); + + std::vector> + generate(const StorageView& features, + std::vector> prompts, + MoonshineOptions options = {}); + }; + + } +} diff --git a/include/ctranslate2/ops/layer_norm.h b/include/ctranslate2/ops/layer_norm.h index 05424fa67..14d0adf90 100644 --- a/include/ctranslate2/ops/layer_norm.h +++ b/include/ctranslate2/ops/layer_norm.h @@ -7,7 +7,7 @@ namespace ctranslate2 { class LayerNorm : public TernaryOp { public: - LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5); + LayerNorm(const dim_t axis = -1, const float epsilon = 1e-5, const bool multi_axis=false); using TernaryOp::operator(); void operator()(const StorageView& beta, @@ -32,10 +32,12 @@ namespace ctranslate2 { const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, + const bool multi_axis, StorageView& output) const; const dim_t _axis; const float _epsilon; + const bool _multi_axis; }; } diff --git a/python/cpp/module.cc b/python/cpp/module.cc index 550aea5b2..203fe4fd2 100644 --- a/python/cpp/module.cc +++ b/python/cpp/module.cc @@ -89,4 +89,5 @@ 
PYBIND11_MODULE(_ext, m) ctranslate2::python::register_wav2vec2(m); ctranslate2::python::register_wav2vec2bert(m); ctranslate2::python::register_mpi(m); + ctranslate2::python::register_moonshine(m); } diff --git a/python/cpp/module.h b/python/cpp/module.h index 71d4b3b29..997a259eb 100644 --- a/python/cpp/module.h +++ b/python/cpp/module.h @@ -20,6 +20,7 @@ namespace ctranslate2 { void register_wav2vec2(py::module& m); void register_wav2vec2bert(py::module& m); void register_mpi(py::module& m); + void register_moonshine(py::module& m); } } diff --git a/python/cpp/moonshine.cc b/python/cpp/moonshine.cc new file mode 100644 index 000000000..b6394fcb2 --- /dev/null +++ b/python/cpp/moonshine.cc @@ -0,0 +1,237 @@ +#include "module.h" + +#include + +#include "replica_pool.h" + +namespace ctranslate2 { + namespace python { + + class MoonshineWrapper : public ReplicaPoolHelper { + public: + using ReplicaPoolHelper::ReplicaPoolHelper; + + StorageView encode(const StorageView& features, const bool to_cpu) { + return _pool->encode(features, to_cpu).get(); + } + + std::variant, + std::vector>> + generate(const StorageView& audio, + std::variant prompts, + bool asynchronous, + size_t beam_size, + float patience, + size_t num_hypotheses, + float length_penalty, + float repetition_penalty, + size_t no_repeat_ngram_size, + size_t max_length, + bool return_scores, + bool suppress_blank, + const std::optional>& suppress_tokens, + size_t sampling_topk, + float sampling_temperature) { + std::vector> futures; + + models::MoonshineOptions options; + options.beam_size = beam_size; + options.patience = patience; + options.length_penalty = length_penalty; + options.repetition_penalty = repetition_penalty; + options.no_repeat_ngram_size = no_repeat_ngram_size; + options.sampling_topk = sampling_topk; + options.sampling_temperature = sampling_temperature; + options.max_length = max_length; + options.num_hypotheses = num_hypotheses; + options.return_scores = return_scores; + options.suppress_blank = suppress_blank; + + if (suppress_tokens) + options.suppress_tokens = suppress_tokens.value(); + else + options.suppress_tokens.clear(); + std::shared_lock lock(_mutex); + assert_model_is_ready(); + + if (prompts.index() == 0) + futures = _pool->generate(audio, std::get(prompts), options); + else + futures = _pool->generate(audio, std::get(prompts), options); + + return maybe_wait_on_futures(std::move(futures), asynchronous); + } + + + }; + + + void register_moonshine(py::module& m) { + py::class_(m, "MoonshineGenerationResult", + "A generation result from the Moonshine model.") + + .def_readonly("sequences", &models::MoonshineGenerationResult::sequences, + "Generated sequences of tokens.") + .def_readonly("sequences_ids", &models::MoonshineGenerationResult::sequences_ids, + "Generated sequences of token IDs.") + .def_readonly("scores", &models::MoonshineGenerationResult::scores, + "Score of each sequence (empty if :obj:`return_scores` was disabled).") + + .def("__repr__", [](const models::MoonshineGenerationResult& result) { + return "MoonshineGenerationResult(sequences=" + std::string(py::repr(py::cast(result.sequences))) + + ", sequences_ids=" + std::string(py::repr(py::cast(result.sequences_ids))) + + ", scores=" + std::string(py::repr(py::cast(result.scores))) + + ")"; + }) + ; + + declare_async_wrapper(m, "MoonshineGenerationResultAsync"); + + py::class_( + m, "Moonshine", + R"pbdoc( + Implements the Moonshine speech recognition model published by Useful Sensors. 
+ + See Also: + https://github.com/usefulsensors/moonshine + )pbdoc") + + .def(py::init>&, const StringOrMap&, size_t, size_t, long, bool, bool, py::object>(), + py::arg("model_path"), + py::arg("device")="cpu", + py::kw_only(), + py::arg("device_index")=0, + py::arg("compute_type")="default", + py::arg("inter_threads")=1, + py::arg("intra_threads")=0, + py::arg("max_queued_batches")=0, + py::arg("flash_attention")=false, + py::arg("tensor_parallel")=false, + py::arg("files")=py::none(), + R"pbdoc( + Initializes a Moonshine model from a converted model. + + Arguments: + model_path: Path to the CTranslate2 model directory. + device: Device to use (possible values are: cpu, cuda, auto). + device_index: Device IDs where to place this model on. + compute_type: Model computation type or a dictionary mapping a device name + to the computation type (possible values are: default, auto, int8, int8_float32, + int8_float16, int8_bfloat16, int16, float16, bfloat16, float32). + inter_threads: Number of workers to allow executing multiple batches in parallel. + intra_threads: Number of OpenMP threads per worker (0 to use a default value). + max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited, + 0 for an automatic value). When the queue is full, future requests will block + until a free slot is available. + flash_attention: run model with flash attention 2 for self-attention layer + tensor_parallel: run model with tensor parallel mode + files: Load model files from the memory. This argument is a dictionary mapping + file names to file contents as file-like or bytes objects. If this is set, + :obj:`model_path` acts as an identifier for this model. + )pbdoc") + + .def_property_readonly("device", &MoonshineWrapper::device, + "Device this model is running on.") + .def_property_readonly("device_index", &MoonshineWrapper::device_index, + "List of device IDs where this model is running on.") + .def_property_readonly("compute_type", &MoonshineWrapper::compute_type, + "Computation type used by the model.") + .def_property_readonly("num_workers", &MoonshineWrapper::num_replicas, + "Number of model workers backing this instance.") + .def_property_readonly("num_queued_batches", &MoonshineWrapper::num_queued_batches, + "Number of batches waiting to be processed.") + .def_property_readonly("tensor_parallel", &MoonshineWrapper::tensor_parallel, + "Run model with tensor parallel mode.") + .def_property_readonly("num_active_batches", &MoonshineWrapper::num_active_batches, + "Number of batches waiting to be processed or currently processed.") + + .def("encode", &MoonshineWrapper::encode, + py::arg("audio"), + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Encodes the input audio. + + Arguments: + audio: Audio, as a float array with shape + ``[batch_size, 1, audio_length]``. + to_cpu: Copy the encoder output to the CPU before returning the value. + + Returns: + The encoder output. 
+ )pbdoc") + + .def("generate", &MoonshineWrapper::generate, + py::arg("audio"), + py::arg("prompts"), + py::kw_only(), + py::arg("asynchronous")=false, + py::arg("beam_size")=1, + py::arg("patience")=1, + py::arg("num_hypotheses")=1, + py::arg("length_penalty")=1, + py::arg("repetition_penalty")=1, + py::arg("no_repeat_ngram_size")=0, + py::arg("max_length")=448, + py::arg("return_scores")=false, + py::arg("suppress_blank")=true, + py::arg("suppress_tokens")=std::vector{-1}, + py::arg("sampling_topk")=1, + py::arg("sampling_temperature")=1, + py::call_guard(), + R"pbdoc( + Encodes the input features and generates from the given prompt. + + Arguments: + audio: Audio in [batch_size, audio_len] shaped array. + prompts: Prompt for model. Defaults to [' 1 to penalize). + no_repeat_ngram_size: Prevent repetitions of ngrams with this size + (set 0 to disable). + max_length: Maximum generation length. + return_scores: Include the scores in the output. + suppress_blank: Suppress blank outputs at the beginning of the sampling. + suppress_tokens: List of token IDs to suppress. -1 will suppress a default set + of symbols as defined in the model ``config.json`` file. + sampling_topk: Randomly sample predictions from the top K candidates. + sampling_temperature: Sampling temperature to generate more random samples. + + Returns: + A list of generation results. + )pbdoc") + + .def("unload_model", &MoonshineWrapper::unload_model, + py::arg("to_cpu")=false, + py::call_guard(), + R"pbdoc( + Unloads the model attached to this whisper but keep enough runtime context + to quickly resume whisper on the initial device. + + Arguments: + to_cpu: If ``True``, the model is moved to the CPU memory and not fully unloaded. + )pbdoc") + + .def("load_model", &MoonshineWrapper::load_model, + py::arg("keep_cache")=false, + py::call_guard(), + R"pbdoc( + Loads the model back to the initial device. + + Arguments: + keep_cache: If ``True``, the model cache in the CPU memory is not deleted if it exists. 
+ )pbdoc") + + .def_property_readonly("model_is_loaded", &MoonshineWrapper::model_is_loaded, + "Whether the model is loaded on the initial device and ready to be used.") + ; + } + + } +} diff --git a/python/ctranslate2/converters/__init__.py b/python/ctranslate2/converters/__init__.py index eaccffee4..a15f876da 100644 --- a/python/ctranslate2/converters/__init__.py +++ b/python/ctranslate2/converters/__init__.py @@ -1,6 +1,7 @@ from ctranslate2.converters.converter import Converter from ctranslate2.converters.fairseq import FairseqConverter from ctranslate2.converters.marian import MarianConverter +from ctranslate2.converters.moonshine import MoonshineConverter from ctranslate2.converters.openai_gpt2 import OpenAIGPT2Converter from ctranslate2.converters.opennmt_py import OpenNMTPyConverter from ctranslate2.converters.opennmt_tf import OpenNMTTFConverter diff --git a/python/ctranslate2/converters/moonshine.py b/python/ctranslate2/converters/moonshine.py new file mode 100644 index 000000000..eae36ec5e --- /dev/null +++ b/python/ctranslate2/converters/moonshine.py @@ -0,0 +1,180 @@ +import argparse +import json +import re + +import numpy as np + +from safetensors.torch import safe_open + +from ctranslate2.converters import utils +from ctranslate2.converters.converter import Converter +from ctranslate2.specs import ( + TransformerDecoderSpec, + TransformerEncoderSpec, + TransformerSpec, +) +from ctranslate2.specs.common_spec import Activation +from ctranslate2.specs.moonshine_spec import MoonshineSpec + + +class MoonshineConverter(Converter): + def __init__(self, safetensor_file, vocab_file, moonshine_variant): + self.safetensor_file = safetensor_file + self.vocab_file = vocab_file + if moonshine_variant == "tiny": + self.layers = 6 + self.heads = 8 + elif moonshine_variant == "base": + self.layers = 8 + self.heads = 8 + else: + raise ValueError('moonshine_variant must be one of ["tiny", "base"]') + + def _load(self): + spec = MoonshineSpec( + num_encoder_layers=self.layers, + num_encoder_heads=self.heads, + num_decoder_layers=self.layers, + num_decoder_heads=self.heads, + ) + self.load_preprocessor(spec.preprocessor) + self.load_encoder(spec.encoder) + self.load_decoder(spec.decoder) + spec.register_vocabulary(self.load_vocab()) + return spec + + def load_vocab(self): + tokens_dict = {} + with open(self.vocab_file, encoding="utf-8") as f: + tokenizer_dict = json.load(f) + d = tokenizer_dict["model"]["vocab"] + for token in d.keys(): + idx = d[token] + token = re.sub(r"\\([^x])", r"\1", token) + token = token[1:-1] + if token.startswith("\\x"): + # Convert the digraph \x to the actual escaped sequence. 
+ token = chr(int(token[2:], base=16)) + elif token.startswith("'") and token.endswith("'"): + token = token[1:-1] + token = token.replace("''", "'") + if idx is not None: + tokens_dict[idx] = token + added_tokens = tokenizer_dict["added_tokens"] + for t in added_tokens: + tokens_dict[t["id"]] = t["content"] + + return [tokens_dict[idx] for idx in sorted(tokens_dict.keys())] + + def load_attention(self, att_spec, st_prefix, self_attention=True): + st = safe_open(self.safetensor_file, framework="pt", device="cpu") + attn_w = [ + st.get_tensor(f"{st_prefix}.to_{dst}.weight") for dst in ["q", "k", "v"] + ] + if self_attention: + att_spec.linear[0].weight = np.concatenate(attn_w) + else: + att_spec.linear[0].weight = attn_w[0] + att_spec.linear[1].weight = np.concatenate(attn_w[1:]) + att_spec.linear[-1].weight = st.get_tensor(f"{st_prefix}.to_out.weight") + + def load_ffn(self, ffn_spec, st_prefix, swiglu=False): + st = safe_open(self.safetensor_file, framework="pt", device="cpu") + if swiglu: + ffn_spec.linear_0_noact.weight = st.get_tensor( + f"{st_prefix}.ff_noact.weight" + ) + ffn_spec.linear_0.weight = st.get_tensor(f"{st_prefix}.ff_proj.weight") + ffn_spec.linear_0_noact.bias = st.get_tensor(f"{st_prefix}.ff_noact.bias") + ffn_spec.linear_0.bias = st.get_tensor(f"{st_prefix}.ff_proj.bias") + ffn_spec.linear_1.weight = st.get_tensor(f"{st_prefix}.ff_out.weight") + ffn_spec.linear_1.bias = st.get_tensor(f"{st_prefix}.ff_out.bias") + else: + ffn_spec.linear_0.weight = st.get_tensor(f"{st_prefix}.ff.0.weight") + ffn_spec.linear_0.bias = st.get_tensor(f"{st_prefix}.ff.0.bias") + ffn_spec.linear_1.weight = st.get_tensor(f"{st_prefix}.ff.2.weight") + ffn_spec.linear_1.bias = st.get_tensor(f"{st_prefix}.ff.2.bias") + + def load_layernorm(self, ln_spec, ln_prefix): + st = safe_open(self.safetensor_file, framework="pt", device="cpu") + ln_spec.gamma = st.get_tensor(f"{ln_prefix}.weight") + ln_spec.beta = np.zeros(ln_spec.gamma.shape) + + def load_embeddings(self, embedding_spec, embedding_prefix): + st = safe_open(self.safetensor_file, framework="pt", device="cpu") + embedding_spec.weight = st.get_tensor(f"{embedding_prefix}.weight") + + def load_preprocessor(self, preprocess_spec): + st = safe_open(self.safetensor_file, framework="pt", device="cpu") + preprocess_prefix = "model.preprocessor.audio_preprocess" + preprocess_spec.conv1.weight = st.get_tensor(f"{preprocess_prefix}.0.weight") + preprocess_spec.layernorm.gamma = st.get_tensor(f"{preprocess_prefix}.2.weight") + preprocess_spec.layernorm.beta = st.get_tensor(f"{preprocess_prefix}.2.bias") + preprocess_spec.conv2.weight = st.get_tensor(f"{preprocess_prefix}.3.weight") + preprocess_spec.conv2.bias = st.get_tensor(f"{preprocess_prefix}.3.bias") + preprocess_spec.conv3.weight = st.get_tensor(f"{preprocess_prefix}.5.weight") + preprocess_spec.conv3.bias = st.get_tensor(f"{preprocess_prefix}.5.bias") + + def load_encoder(self, encoder_spec): + self.load_layernorm(encoder_spec.layer_norm, "model.encoder.post_norm") + for idx, l in enumerate(encoder_spec.layer): + self.load_attention( + l.self_attention, f"model.encoder.layers.{idx}.attention" + ) + self.load_layernorm( + l.self_attention.layer_norm, f"model.encoder.layers.{idx}.norm1" + ) + self.load_ffn(l.ffn, f"model.encoder.layers.{idx}.ff") + self.load_layernorm(l.ffn.layer_norm, f"model.encoder.layers.{idx}.norm2") + + def load_decoder(self, decoder_spec): + self.load_layernorm(decoder_spec.layer_norm, "model.decoder.final_norm") + self.load_embeddings(decoder_spec.embeddings, 
"model.decoder.token_embedding") + decoder_spec.projection.weight = decoder_spec.embeddings.weight + for idx, l in enumerate(decoder_spec.layer): + self.load_attention( + l.self_attention, f"model.decoder.layers.{idx}.self_attention" + ) + self.load_layernorm( + l.self_attention.layer_norm, f"model.decoder.layers.{idx}.norm1" + ) + self.load_attention( + l.attention, + f"model.decoder.layers.{idx}.cross_attention", + self_attention=False, + ) + self.load_layernorm( + l.attention.layer_norm, f"model.decoder.layers.{idx}.norm2" + ) + self.load_ffn(l.ffn, f"model.decoder.layers.{idx}.ff", swiglu=True) + self.load_layernorm(l.ffn.layer_norm, f"model.decoder.layers.{idx}.norm3") + + +def main(): + parser = argparse.ArgumentParser( + formatter_class=argparse.ArgumentDefaultsHelpFormatter + ) + parser.add_argument( + "--model_path", required=True, help="Path to the model .safetensor file." + ) + parser.add_argument( + "--vocab_path", + required=True, + help="Path to tokenizer.json config file.", + ) + parser.add_argument( + "--moonshine_variant", + required=True, + help="Moonshine variant to convert. Must be one of ['tiny', 'base']", + ) + + Converter.declare_arguments(parser) + args = parser.parse_args() + converter = MoonshineConverter( + args.model_path, args.vocab_path, args.moonshine_variant + ) + converter.convert_from_args(args) + + +if __name__ == "__main__": + main() diff --git a/python/ctranslate2/models/__init__.py b/python/ctranslate2/models/__init__.py index 35a3dca37..460d8570b 100644 --- a/python/ctranslate2/models/__init__.py +++ b/python/ctranslate2/models/__init__.py @@ -4,6 +4,7 @@ try: from ctranslate2._ext import ( + Moonshine, Wav2Vec2, Wav2Vec2Bert, Whisper, diff --git a/python/ctranslate2/specs/__init__.py b/python/ctranslate2/specs/__init__.py index b4e53fad2..54ba7557d 100644 --- a/python/ctranslate2/specs/__init__.py +++ b/python/ctranslate2/specs/__init__.py @@ -6,6 +6,7 @@ ModelSpec, SequenceToSequenceModelSpec, ) +from ctranslate2.specs.moonshine_spec import MoonshineSpec from ctranslate2.specs.transformer_spec import ( TransformerDecoderModelSpec, TransformerDecoderSpec, diff --git a/python/ctranslate2/specs/moonshine_spec.py b/python/ctranslate2/specs/moonshine_spec.py new file mode 100644 index 000000000..0ca076e66 --- /dev/null +++ b/python/ctranslate2/specs/moonshine_spec.py @@ -0,0 +1,84 @@ +from typing import List, Optional, Tuple + +import numpy as np + +from ctranslate2.specs import common_spec, model_spec, transformer_spec + + +class MoonshineConfig(model_spec.ModelConfig): + """Configuration for the Moonshine model.""" + + def __init__( + self, + suppress_ids: Optional[List[int]] = None, + suppress_ids_begin: Optional[List[int]] = None, + lang_ids: Optional[List[int]] = None, + alignment_heads: Optional[List[Tuple[int, int]]] = None, + ): + super().__init__( + suppress_ids=suppress_ids, + suppress_ids_begin=suppress_ids_begin, + lang_ids=lang_ids, + alignment_heads=alignment_heads, + ) + + +class MoonshineSpec(model_spec.LanguageModelSpec): + """Describes a Whisper model.""" + + def __init__( + self, + num_encoder_layers, + num_encoder_heads, + num_decoder_layers, + num_decoder_heads, + ): + """Initializes the model specification. + + Args: + num_encoder_layers: The number of encoder layers. + num_encoder_heads: The number of encoder attention heads. + num_decoder_layers: The number of decoder layers. + num_decoder_heads: The number of decoder attention heads. 
+ """ + super().__init__() + self.preprocessor = AudioPreprocessSpec() + self.encoder = transformer_spec.TransformerEncoderSpec( + num_layers=num_encoder_layers, + num_heads=num_encoder_heads, + activation=common_spec.Activation.GELU, + num_source_embeddings=0, + rotary_dim=32, + ) + self.decoder = transformer_spec.TransformerDecoderSpec( + num_layers=num_decoder_layers, + num_heads=num_decoder_heads, + activation=common_spec.Activation.SWISH, + ffn_glu=True, + with_encoder_attention=True, + project_in_out=False, + rotary_dim=32, + ) + self.decoder.scale_embeddings = False + + @property + def name(self): + return "MoonshineSpec" + + @property + def revision(self): + return 0 + + def get_default_config(self): + return MoonshineConfig() + + def get_vocabulary_size(self): + return self.decoder.embeddings.weight.shape[0] + + +class AudioPreprocessSpec(model_spec.LayerSpec): + def __init__(self): + self.conv1 = common_spec.Conv1DSpec() + self.layernorm = common_spec.LayerNormSpec() + self.conv2 = common_spec.Conv1DSpec() + self.conv3 = common_spec.Conv1DSpec() diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py index 230e62cfd..60691d27c 100644 --- a/python/ctranslate2/specs/transformer_spec.py +++ b/python/ctranslate2/specs/transformer_spec.py @@ -22,6 +22,11 @@ def __init__( relative_attention_bias: bool = False, ffn_glu: bool = False, rms_norm: bool = False, + rotary_dim: Optional[int] = None, + rotary_interleave: bool = True, + rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None, + rotary_scaling_factor: float = 1, + rotary_base: float = 10000, multi_query_attention: bool = False, ): """Initializes a Transformer encoder specification. @@ -66,6 +71,11 @@ def __init__( relative_attention_bias=relative_attention_bias, ffn_glu=ffn_glu, rms_norm=rms_norm, + rotary_dim=rotary_dim, + rotary_interleave=rotary_interleave, + rotary_scaling_type=rotary_scaling_type, + rotary_scaling_factor=rotary_scaling_factor, + rotary_base=rotary_base, num_heads_kv=1 if multi_query_attention else None, ) for _ in range(num_layers) @@ -251,6 +261,11 @@ def __init__( relative_attention_bias=False, ffn_glu=False, rms_norm=False, + rotary_dim=None, + rotary_interleave=True, + rotary_scaling_type=None, + rotary_scaling_factor=1, + rotary_base=10000, num_heads_kv=None, sliding_window=None, ): @@ -259,6 +274,11 @@ def __init__( relative_position=relative_position, relative_attention_bias=relative_attention_bias, rms_norm=rms_norm, + rotary_dim=rotary_dim, + rotary_interleave=rotary_interleave, + rotary_scaling_type=rotary_scaling_type, + rotary_scaling_factor=rotary_scaling_factor, + rotary_base=rotary_base, num_heads_kv=num_heads_kv, sliding_window=sliding_window, ) diff --git a/python/setup.py b/python/setup.py index 7f56d6074..ae067bfa4 100644 --- a/python/setup.py +++ b/python/setup.py @@ -110,6 +110,7 @@ def _maybe_add_library_root(lib_name): "console_scripts": [ "ct2-fairseq-converter=ctranslate2.converters.fairseq:main", "ct2-marian-converter=ctranslate2.converters.marian:main", + "ct2-moonshine-converter=ctranslate2.converters.moonshine:main", "ct2-openai-gpt2-converter=ctranslate2.converters.openai_gpt2:main", "ct2-opennmt-py-converter=ctranslate2.converters.opennmt_py:main", "ct2-opennmt-tf-converter=ctranslate2.converters.opennmt_tf:main", diff --git a/python/tests/requirements.txt b/python/tests/requirements.txt index c5f4812a7..663819be2 100644 --- a/python/tests/requirements.txt +++ b/python/tests/requirements.txt @@ -6,3 +6,4 @@ 
tensorflow-cpu==2.11.* pytest wurlitzer==3.0.*;platform_system=='Linux' torch==2.2.0 +safetensors diff --git a/src/cpu/kernels.cc b/src/cpu/kernels.cc index c1f48553d..4220c485d 100644 --- a/src/cpu/kernels.cc +++ b/src/cpu/kernels.cc @@ -465,6 +465,7 @@ namespace ctranslate2 { const float* beta, float* output, dim_t batch_size, + dim_t weights_size, dim_t depth, float epsilon) { parallel_for(0, batch_size, 1, [&](dim_t begin, dim_t end) { @@ -488,8 +489,12 @@ namespace ctranslate2 { const float variance = std::max(sum_squares / depth - mean * mean, 0.f); const float rstd = 1.f / std::sqrt(variance + epsilon); - for (dim_t j = 0; j < depth; ++j) { - y[j] = (x[j] - mean) * rstd * gamma[j] + beta[j]; + int inner_dim = depth / weights_size; + for (dim_t j = 0; j < inner_dim; j ++) { + for (dim_t k = 0; k < weights_size; k++) { + int idx = k * inner_dim + j; + y[idx] = (x[idx] - mean) * rstd * gamma[k] + beta[k]; + } } } }); diff --git a/src/cpu/kernels.h b/src/cpu/kernels.h index 16296fc36..51cb36182 100644 --- a/src/cpu/kernels.h +++ b/src/cpu/kernels.h @@ -78,6 +78,7 @@ namespace ctranslate2 { const float* beta, float* output, dim_t batch_size, + dim_t weights_size, dim_t depth, float epsilon); diff --git a/src/layers/common.cc b/src/layers/common.cc index c6d1cd0b5..5f5069786 100644 --- a/src/layers/common.cc +++ b/src/layers/common.cc @@ -453,7 +453,10 @@ namespace ctranslate2 { } void LayerNorm::operator()(const StorageView& input, StorageView& output) const { - if (_beta) { + if (_gamma.size() != input.dim(input.rank() - 1) && _gamma.size() == input.dim(input.rank() - 2)) { + const ops::LayerNorm norm_op(-2, _epsilon, true); + norm_op(*_beta, _gamma, input, output); + } else if (_beta) { const ops::LayerNorm norm_op(-1, _epsilon); norm_op(*_beta, _gamma, input, output); } else { diff --git a/src/layers/moonshine.cc b/src/layers/moonshine.cc new file mode 100644 index 000000000..51afeead3 --- /dev/null +++ b/src/layers/moonshine.cc @@ -0,0 +1,70 @@ +#include "ctranslate2/layers/moonshine.h" + +namespace ctranslate2 { + namespace layers { + MoonshinePreprocessor::MoonshinePreprocessor(const models::Model& model, const std::string& scope) + : _conv1(model, scope + "/conv1", /*stride=*/64, /*padding=*/0), + _tanh(), + _norm(model, scope + "/layernorm"), + _conv2(model, scope + "/conv2", /*stride=*/3, /*padding=*/0), + _gelu1(), + _conv3(model, scope + "/conv3", /*stride=*/2, /*padding=*/0), + _gelu2(), + _transpose({0, 2, 1}) {} + + void MoonshinePreprocessor::operator()(const StorageView& features, StorageView& output) { + if (features.rank() != 2) + throw std::invalid_argument("Expected input features to have 2 dimensions, but got " + + std::to_string(features.rank()) + + " dimension(s) instead"); + + StorageView input(output_type(), features.device()); + StorageView input_reshaped = std::move(features); + input_reshaped.expand_dims(1); + + _conv1(input_reshaped, input); + _tanh(input, input); + _norm(input, input); + + _conv2(input, output); + _gelu1(output, output); + + _conv3(output, input); + _gelu2(input, input); + _transpose(input, output); + } + + + MoonshineEncoder::MoonshineEncoder(const models::Model& model, const std::string& scope) + : _num_heads(model.get_attribute_with_default(scope + "/num_heads", 8)) + , _layers(build_layers_list(model, + scope + "/layer", + _num_heads, + /*pre_norm=*/true, + ops::ActivationType::GELU)) + , _output_norm(model, scope + "/layer_norm") + { + } + + void MoonshineEncoder::operator()(const StorageView& features, StorageView& output) { + 
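    // features holds the preprocessor output with shape [batch_size, time, dim];
    // each Transformer layer applies unmasked self-attention followed by its feed-forward
    // block, and the final layer norm produces the memory consumed by the decoder
    // cross-attention.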
PROFILE("MoonshineEncoder"); + + if (features.rank() != 3) + throw std::invalid_argument("Expected input features to have 3 dimensions, but got " + + std::to_string(features.rank()) + + " dimension(s) instead"); + + StorageView input(output_type(), features.device()); + + input = std::move(features); + + for (const auto& layer : _layers) { + (*layer)(input, nullptr, output); + input = std::move(output); + } + + _output_norm(input, output); + } + + } +} diff --git a/src/models/model.cc b/src/models/model.cc index b8e1c2d8f..c2bb8e3f6 100644 --- a/src/models/model.cc +++ b/src/models/model.cc @@ -213,6 +213,9 @@ namespace ctranslate2 { if (device == Device::CUDA #ifdef CT2_WITH_DNNL || true +#endif +#ifdef MOONSHINE + || true #endif ) { variable_weight_dtype = float_dtype; diff --git a/src/models/model_factory.cc b/src/models/model_factory.cc index 059051f5d..ac2404eda 100644 --- a/src/models/model_factory.cc +++ b/src/models/model_factory.cc @@ -3,6 +3,7 @@ #include #include "ctranslate2/models/whisper.h" +#include "ctranslate2/models/moonshine.h" #include "ctranslate2/models/wav2vec2.h" #include "ctranslate2/models/wav2vec2bert.h" #include "ctranslate2/models/transformer.h" @@ -26,6 +27,8 @@ namespace ctranslate2 { register_model("Wav2Vec2Spec"); register_model("Wav2Vec2BertSpec"); + + register_model("MoonshineSpec"); } std::shared_ptr create_model(const std::string& name) { diff --git a/src/models/moonshine.cc b/src/models/moonshine.cc new file mode 100644 index 000000000..db9179f14 --- /dev/null +++ b/src/models/moonshine.cc @@ -0,0 +1,234 @@ +#include "ctranslate2/models/moonshine.h" + +#include + +#include "ctranslate2/decoding.h" + +#include "dispatch.h" +#include "dtw.h" + +#ifdef CT2_WITH_CUDA +# include "cuda/utils.h" +#endif + +namespace ctranslate2 { + namespace models { + + const Vocabulary& MoonshineModel::get_vocabulary() const { + return *_vocabulary; + } + + size_t MoonshineModel::current_spec_revision() const { + return 0; + } + + void MoonshineModel::initialize(ModelReader& model_reader) { + VocabularyInfo vocab_info; + vocab_info.unk_token = ""; + vocab_info.bos_token = ""; + vocab_info.eos_token = ""; + + _vocabulary = load_vocabulary(model_reader, "vocabulary", std::move(vocab_info)); + if (!_vocabulary) + throw std::runtime_error("Cannot load the vocabulary from the model directory"); + } + + bool MoonshineModel::is_quantizable(const std::string& variable_name) const { + return Model::is_quantizable(variable_name); + } + + bool MoonshineModel::is_linear_weight(const std::string& variable_name) const { + return is_quantizable(variable_name) && variable_name.find("embeddings") == std::string::npos; + } + + std::unique_ptr MoonshineModel::clone() const { + return std::make_unique(*this); + } + + + std::unique_ptr MoonshineReplica::create_from_model(const Model& model) { + if (!dynamic_cast(&model)) + throw std::invalid_argument("The model is not a Moonshine model"); + + const auto scoped_device_setter = model.get_scoped_device_setter(); + const auto model_ptr = model.shared_from_this(); + const auto concrete_model = std::static_pointer_cast(model_ptr); + return std::make_unique(concrete_model); + } + + MoonshineReplica::MoonshineReplica(const std::shared_ptr& model) + : ModelReplica(model) + , _model(model) + , _preprocessor(std::make_unique(*model, "preprocessor")) + , _encoder(std::make_unique(*model, "encoder")) + , _decoder(std::make_unique(*model, "decoder")) + { + const auto& vocabulary = model->get_vocabulary(); + _sot_id = vocabulary.bos_id(); + _eot_id = 
vocabulary.eos_id(); + } + + StorageView MoonshineReplica::encode(StorageView features, const bool to_cpu) { + PROFILE("MoonshineReplica::encode"); + +#ifdef CT2_WITH_CUDA + const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false); +#endif + + const auto scoped_device_setter = _model->get_scoped_device_setter(); + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + features.move_to(device, dtype); + + StorageView encoder_input(dtype, device); + StorageView encoder_output(dtype, device); + (*_preprocessor)(features, encoder_input); + (*_encoder)(encoder_input, encoder_output); + + if (to_cpu) { + if (device != Device::CPU) + encoder_output = encoder_output.to(Device::CPU); + return encoder_output; + } + + // Ensure all operations are finished before returning the output. + synchronize_stream(device); + + return encoder_output; + } + + StorageView MoonshineReplica::maybe_encode(StorageView features) { + const Device device = _model->device(); + const DataType dtype = _encoder->output_type(); + + features.move_to(device, dtype); + + if (_encoder->is_encoded(features)) + return features; + + StorageView encoder_input(dtype, device); + StorageView encoder_output(dtype, device); + (*_preprocessor)(features, encoder_input); + (*_encoder)(encoder_input, encoder_output); + return encoder_output; + } + + std::vector + MoonshineReplica::generate(StorageView features, + const std::vector>& prompts, + const MoonshineOptions& options) { + const auto& vocabulary = _model->get_vocabulary(); + return generate(std::move(features), vocabulary.to_ids(prompts), options); + } + + std::vector + MoonshineReplica::generate(StorageView features, + const std::vector>& prompts, + const MoonshineOptions& options) { + PROFILE("MoonshineReplica::generate"); + if (prompts.empty()) + return {}; + +#ifdef CT2_WITH_CUDA + const cuda::UseTrueFp16GemmInScope use_true_fp16_gemm(false); +#endif + + const auto& vocabulary = _model->get_vocabulary(); + const auto scoped_device_setter = _model->get_scoped_device_setter(); + + layers::DecoderState state = _decoder->initial_state(); + state.emplace("memory", maybe_encode(std::move(features))); + + _decoder->update_output_layer(_model->preferred_size_multiple()); + + const dim_t total_max_length = options.max_length; + + DecodingOptions decoding_options; + decoding_options.start_step = 0; + decoding_options.beam_size = options.beam_size; + decoding_options.patience = options.patience; + decoding_options.length_penalty = options.length_penalty; + decoding_options.repetition_penalty = options.repetition_penalty; + decoding_options.no_repeat_ngram_size = options.no_repeat_ngram_size; + decoding_options.max_length = total_max_length; + decoding_options.sampling_topk = options.sampling_topk; + decoding_options.sampling_temperature = options.sampling_temperature; + decoding_options.num_hypotheses = options.num_hypotheses; + decoding_options.return_scores = options.return_scores; + decoding_options.include_eos_in_hypotheses = false; + + for (const auto& id : options.suppress_tokens) { + if (id >= 0) + decoding_options.disable_ids.push_back(id); + else if (id == -1) { + for (const auto& default_id : _model->config["suppress_ids"]) + decoding_options.disable_ids.push_back(default_id); + } + } + + if (options.suppress_blank) { + for (const auto& id : _model->config["suppress_ids_begin"]) + decoding_options.disable_ids_begin.push_back(id); + } + + std::vector results = decode(*_decoder, + state, + prompts, + {_eot_id}, + decoding_options); + + 
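      // decode() returns generic decoding results; convert them to Moonshine results by
      // mapping the hypothesis token IDs back to token strings with the model vocabulary.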
std::vector final_results; + final_results.reserve(results.size()); + + for (size_t i = 0; i < results.size(); ++i) { + auto& result = results[i]; + + MoonshineGenerationResult final_result; + final_result.sequences = vocabulary.to_tokens(result.hypotheses); + final_result.sequences_ids = std::move(result.hypotheses); + final_result.scores = std::move(result.scores); + + final_results.emplace_back(std::move(final_result)); + } + + return final_results; + } + + std::future Moonshine::encode(const StorageView& features, const bool to_cpu) { + return post( + [features = features.sync_copy(), to_cpu](MoonshineReplica& replica) mutable { + return replica.encode(std::move(features), to_cpu); + }); + } + + std::vector> + Moonshine::generate(const StorageView& features, + std::vector> prompts, + MoonshineOptions options) { + const size_t batch_size = features.dim(0); + return post_batch( + [features = features.sync_copy(), + prompts = std::move(prompts), + options = std::move(options)] + (MoonshineReplica& replica) mutable { + return replica.generate(std::move(features), prompts, options); + }, + batch_size); + } + + std::vector> + Moonshine::generate(const StorageView& features, + std::vector> prompts, + MoonshineOptions options) { + const size_t batch_size = features.dim(0); + return post_batch( + [features = features.sync_copy(), + prompts = std::move(prompts), + options = std::move(options)] + (MoonshineReplica& replica) mutable { + return replica.generate(std::move(features), prompts, options); + }, + batch_size); + } + } +} diff --git a/src/ops/layer_norm.cc b/src/ops/layer_norm.cc index f21fe0b24..521ffc2e5 100644 --- a/src/ops/layer_norm.cc +++ b/src/ops/layer_norm.cc @@ -5,9 +5,10 @@ namespace ctranslate2 { namespace ops { - LayerNorm::LayerNorm(const dim_t axis, const float epsilon) + LayerNorm::LayerNorm(const dim_t axis, const float epsilon, const bool multi_axis) : _axis(axis) , _epsilon(epsilon) + , _multi_axis(multi_axis) { } @@ -51,6 +52,7 @@ namespace ctranslate2 { outer_size, axis_size, inner_size, + _multi_axis, output))); } diff --git a/src/ops/layer_norm_cpu.cc b/src/ops/layer_norm_cpu.cc index 60441a265..17ef48435 100644 --- a/src/ops/layer_norm_cpu.cc +++ b/src/ops/layer_norm_cpu.cc @@ -13,6 +13,7 @@ namespace ctranslate2 { const dim_t outer_size, const dim_t axis_size, const dim_t inner_size, + const bool multi_axis, StorageView& output) const { if (axis == input.rank() - 1 && beta && gamma) { CPU_ISA_DISPATCH((cpu::layer_norm(input.data(), @@ -20,8 +21,18 @@ namespace ctranslate2 { beta->data(), output.data(), outer_size, + gamma->size(), axis_size, _epsilon))); + } else if (multi_axis && axis != input.rank() - 1 && beta && gamma) { + CPU_ISA_DISPATCH((cpu::layer_norm(input.data(), + gamma->data(), + beta->data(), + output.data(), + outer_size, + gamma->size(), + axis_size * inner_size, + _epsilon))); } else { CPU_ISA_DISPATCH((cpu::layer_norm_axis(input.data(), gamma ? gamma->data() : nullptr, @@ -43,6 +54,7 @@ namespace ctranslate2 { const dim_t outer_size, \ const dim_t axis_size, \ const dim_t inner_size, \ + const bool multi_axis, \ StorageView& output) const; DECLARE_IMPL(float) diff --git a/src/ops/layer_norm_gpu.cu b/src/ops/layer_norm_gpu.cu index 8c644d876..893289f5b 100644 --- a/src/ops/layer_norm_gpu.cu +++ b/src/ops/layer_norm_gpu.cu @@ -9,6 +9,7 @@ namespace at { // Forward declaration of the CUDA kernels. 
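  // Note on the added parameter: N is the full normalized span of each row
  // (axis_size * inner_size) while axis_size is the length of the gamma/beta vectors,
  // so the kernel can broadcast the affine parameters over the extra inner dimension
  // (inner_dim = N / axis_size; for the standard last-axis case inner_dim == 1).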
template __global__ void LayerNormForwardCUDAKernel(SizeT N, + SizeT axis_size, float eps, const T* X, const T* gamma, @@ -30,13 +31,15 @@ namespace ctranslate2 { const dim_t axis, const dim_t outer_size, const dim_t axis_size, - const dim_t, + const dim_t inner_size, + const bool multi_axis, StorageView& output) const { - if (axis != input.rank() - 1 || !beta || !gamma) + if (!multi_axis && axis != input.rank() - 1 || !beta || !gamma) throw std::invalid_argument("Generalized LayerNorm is currently not implemented on GPU"); at::native::LayerNormForwardCUDAKernel, cuda::index_t> <<>>( + inner_size * axis_size, axis_size, _epsilon, cuda::device_cast(input.data()), @@ -54,6 +57,7 @@ namespace ctranslate2 { const dim_t outer_size, \ const dim_t axis_size, \ const dim_t inner_size, \ + const bool multi_axis, \ StorageView& output) const; DECLARE_IMPL(float) @@ -147,6 +151,7 @@ namespace at { template __global__ void LayerNormForwardCUDAKernel(SizeT N, + SizeT axis_size, float eps, const T* X, const T* gamma, @@ -179,11 +184,13 @@ namespace at { __syncthreads(); - for (SizeT j = threadIdx.x; j < N; j += blockDim.x) { - const SizeT index = i * N + j; - Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[j]) + float(beta[j]); + SizeT inner_dim = N / axis_size; + for (SizeT j = 0; j < inner_dim; j++) { + for (SizeT k = threadIdx.x; k < axis_size; k += blockDim.x) { + const SizeT index = i * N + k * inner_dim + j; + Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[k]) + float(beta[k]); + } } } - } }
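The LayerNorm changes above are the subtlest part of the patch: with `multi_axis` set, mean and variance are still computed over the whole normalized span of each row, but `gamma`/`beta` have only `weights_size` elements and are broadcast along the trailing inner dimension (the kernels index them as `idx = k * inner_dim + j`). The following NumPy sketch is my reading of the intended reference behaviour, not code from this diff:

```python
import numpy as np

def multi_axis_layer_norm(x, gamma, beta, eps=1e-5):
    """x: [batch, weights_size, inner_dim]; gamma, beta: [weights_size]."""
    batch, weights_size, inner_dim = x.shape
    flat = x.reshape(batch, -1)
    mean = flat.mean(axis=1, keepdims=True)
    var = flat.var(axis=1, keepdims=True)  # population variance, matching the kernels
    normed = ((flat - mean) / np.sqrt(var + eps)).reshape(x.shape)
    # gamma/beta are indexed by the weights axis only (idx = k * inner_dim + j in the kernels).
    return normed * gamma[:, None] + beta[:, None]
```

This is what lets the Moonshine preprocessor apply a channel-wise LayerNorm to the [batch, channels, time] output of its first convolution, via the `axis == -2` dispatch added in `src/layers/common.cc`.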