diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 0e6e6dc..2fc7052 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -17,16 +17,6 @@ from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
-def audio_to_int16(audio: np.ndarray) -> np.ndarray:
- if audio.dtype == np.int16:
- return audio
- elif audio.dtype == np.float32:
- # Multiply by 32767 and not 32768 so that int16 doesn't overflow.
- return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
- else:
- raise TypeError(f"Unsupported audio data type: {audio.dtype}")
-
-
def play_audio(audio: np.ndarray, sample_rate: int):
# Requires the Portaudio library which might not be available in all environments.
import sounddevice as sd
@@ -86,7 +76,8 @@ def main():
)
print("Generating audio...")
- # This doesn't do streaming generation,
+ # This doesn't do streaming generation, but the model allows it. For now, see Rust
+ # example.
result = tts_model.generate([entries], [condition_attributes])
frames = torch.cat(result.frames, dim=-1)
diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb
new file mode 100644
index 0000000..9c89892
--- /dev/null
+++ b/tts_pytorch.ipynb
@@ -0,0 +1,164 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0b7eed16",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "353b9498",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import argparse\n",
+ "import sys\n",
+ "\n",
+ "import numpy as np\n",
+ "import sphn\n",
+ "import torch\n",
+ "from moshi.models.loaders import CheckpointInfo\n",
+ "from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel\n",
+ "\n",
+ "from IPython.display import display, Audio"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "8846418a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "See https://huggingface.co/datasets/kyutai/tts-voices for available voices.\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Configuration\n",
+ "text = \"Hey there! How are you? I had the craziest day today.\"\n",
+ "voice = \"expresso/ex03-ex01_happy_001_channel1_334s.wav\"\n",
+ "print(f\"See https://huggingface.co/datasets/{DEFAULT_DSM_TTS_VOICE_REPO} for available voices.\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "b9f022ec",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set everything up\n",
+ "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
+ "tts_model = TTSModel.from_checkpoint_info(\n",
+ " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
+ ")\n",
+ "\n",
+ "# You could also generate multiple audios at once by passing a list of texts.\n",
+ "entries = tts_model.prepare_script([text], padding_between=1)\n",
+ "voice_path = tts_model.get_voice_path(voice)\n",
+ "# CFG coef goes here because the model was trained with CFG distillation,\n",
+ "# so it's not _actually_ doing CFG at inference time.\n",
+ "condition_attributes = tts_model.make_condition_attributes(\n",
+ " [voice_path], cfg_coef=2.0\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "f4f76c73",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Generating audio...\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Generating audio...\")\n",
+ "\n",
+ "# This doesn't do streaming generation,\n",
+ "result = tts_model.generate([entries], [condition_attributes])\n",
+ "\n",
+ "frames = torch.cat(result.frames, dim=-1)\n",
+ "audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]\n",
+ "with torch.no_grad():\n",
+ " audios = tts_model.mimi.decode(audio_tokens)\n",
+ "\n",
+ "audio = audios[0].cpu().numpy()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "732e4b4b",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "display(\n",
+ " Audio(audio, rate=tts_model.mimi.sample_rate, autoplay=True)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2dbdd275",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.13.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}