diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 0e6e6dc..2fc7052 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -17,16 +17,6 @@ from moshi.models.loaders import CheckpointInfo
 from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
 
 
-def audio_to_int16(audio: np.ndarray) -> np.ndarray:
-    if audio.dtype == np.int16:
-        return audio
-    elif audio.dtype == np.float32:
-        # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
-        return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
-    else:
-        raise TypeError(f"Unsupported audio data type: {audio.dtype}")
-
-
 def play_audio(audio: np.ndarray, sample_rate: int):
     # Requires the Portaudio library which might not be available in all environments.
     import sounddevice as sd
@@ -86,7 +76,8 @@ def main():
     )
 
     print("Generating audio...")
-    # This doesn't do streaming generation,
+    # This doesn't do streaming generation, but the model allows it. For now, see
+    # the Rust example.
     result = tts_model.generate([entries], [condition_attributes])
     frames = torch.cat(result.frames, dim=-1)
 
diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb
new file mode 100644
index 0000000..9c89892
--- /dev/null
+++ b/tts_pytorch.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "0b7eed16",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install \"git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "353b9498",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import argparse\n",
+    "import sys\n",
+    "\n",
+    "import numpy as np\n",
+    "import sphn\n",
+    "import torch\n",
+    "from moshi.models.loaders import CheckpointInfo\n",
+    "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
+    "\n",
+    "from IPython.display import display, Audio"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "8846418a",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n"
+     ]
+    }
+   ],
+   "source": [
+    "# Configuration\n",
+    "text = \"Hey there! How are you? I had the craziest day today.\"\n",
+    "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
+    "print(f\"See https://huggingface.co/datasets/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "b9f022ec",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set everything up\n",
+    "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
+    "tts_model = TTSModel.from_checkpoint_info(\n",
+    "    checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
+    ")\n",
+    "\n",
+    "# You could also generate multiple audios at once by passing a list of texts.\n",
+    "entries = tts_model.prepare_script([text], padding_between=1)\n",
+    "voice_path = tts_model.get_voice_path(voice)\n",
+    "# CFG coef goes here because the model was trained with CFG distillation,\n",
+    "# so it's not _actually_ doing CFG at inference time.\n",
+    "condition_attributes = tts_model.make_condition_attributes(\n",
+    "    [voice_path], cfg_coef=2.0\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "f4f76c73",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Generating audio...\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(\"Generating audio...\")\n",
+    "\n",
+    "# This doesn't do streaming generation, but the model allows it. For now, see the Rust example.\n",
+    "result = tts_model.generate([entries], [condition_attributes])\n",
+    "\n",
+    "frames = torch.cat(result.frames, dim=-1)\n",
+    "audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]\n",
+    "with torch.no_grad():\n",
+    "    audios = tts_model.mimi.decode(audio_tokens)\n",
+    "\n",
+    "audio = audios[0].cpu().numpy()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "732e4b4b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "\n",
+       " \n",
+       " "
+      ],
+      "text/plain": [
+       "<IPython.lib.display.Audio object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
+   "source": [
+    "display(\n",
+    "    Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2dbdd275",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.13.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
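
Note on the notebook: it imports sphn but never calls it, and the generated audio is only played inline. Below is a minimal sketch of a follow-up cell that also writes the audio to disk, assuming the `audio` array and `tts_model` defined in the cells above; the filename tts_output.wav is a placeholder, and sphn.write_wav(path, data, sample_rate) is the helper the sphn package provides for WAV output:

import numpy as np
import sphn

# Write the decoded float32 waveform to a WAV file at mimi's sample rate.
sphn.write_wav("tts_output.wav", audio, tts_model.mimi.sample_rate)

# If an int16 buffer is ever needed (e.g. for other playback backends), the
# helper removed from scripts/tts_pytorch.py can be inlined: multiply by
# 32767 rather than 32768 so the result cannot overflow int16.
audio_i16 = (np.clip(audio, -1, 1) * 32767).astype(np.int16)

The int16 conversion mirrors the audio_to_int16 function that this patch removes from the script, so nothing is lost for users who still need an integer buffer.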