Compare commits


23 Commits

Author SHA1 Message Date
96042c9983 Added install.sh
2025-08-12 18:03:43 +00:00
7e92c6f28a Added app/scripts/tts_runner.py
2025-08-12 17:59:20 +00:00
8e51c6eab9 Updated app/dependency_check.py
2025-08-12 17:58:17 +00:00
34e58b8055 Added dependency_check.py
2025-08-12 17:57:33 +00:00
f23d1be027 Added app/api_server.py
2025-08-12 17:56:32 +00:00
Laurent Mazare
cf97f8d863
Workaround for the mlx kv-cache bug. (#108) 2025-08-04 16:37:00 +02:00
Laurent Mazare
09468c239a
Print the duration of the audio generated so far. (#107) 2025-08-04 09:24:31 +02:00
Laurent Mazare
07729ed47e
Use the proper repos when vad is on. (#103) 2025-08-01 15:55:49 +02:00
Laurent Mazare
af2283de3f
Use a streaming input in the rust example. (#102)
* Use a streaming input in the rust example.

* Formatting.

* Another formatting tweak.
2025-07-31 17:41:57 +02:00
Laurent Mazare
7dc926d50c
Allow for using local voices in the pytorch examples. (#100) 2025-07-31 12:48:05 +02:00
Laurent Mazare
ab8e8c59b7
Bump the version numbers. (#91) 2025-07-19 15:57:53 +02:00
laurent
5f17114618 More faq. 2025-07-18 08:43:10 +02:00
laurent
405a82ba3f FAQ tweaks. 2025-07-18 08:37:51 +02:00
Laurent Mazare
3b584b100c
Sketch a FAQ and add some issue templates. (#88) 2025-07-18 08:31:31 +02:00
Laurent Mazare
a98eb94ade
Add a streaming example for the mlx tts. (#85)
* Add a streaming example for the mlx tts.

* Fix the CI.

* Formatting fix.

* Yet another CI fix.
2025-07-16 22:35:44 +02:00
Laurent Mazare
a2f031deb5
Fix the pytorch tts streaming example. (#84)
* Fix the pytorch tts streaming example.

* Edit the readme too.
2025-07-16 21:07:02 +02:00
Laurent Mazare
66a33c989f
Add a TTS streaming example. (#83)
* Add a TTS streaming example.

* Get the streaming example to work.
2025-07-16 20:54:13 +02:00
laurent
89a2ced839 Bump the moshi version. 2025-07-16 16:02:28 +02:00
Laurent
baf0c75bba VAD support in the mlx-stt example that uses the microphone. 2025-07-08 16:08:32 +02:00
Laurent Mazare
952319de90
Add a MLX STT example that uses VAD. (#70)
* Add a MLX STT example that uses VAD.

* VAD support.

* More MLX VAD example.

* Use the latest moshi-mlx.
2025-07-08 16:04:50 +02:00
laurent
6d3bb6b1f1 Avoid the config override for the extra-heads. 2025-07-08 15:39:17 +02:00
Laurent Mazare
12dbe36b0b
Add some VAD to the pytorch speech-to-text example. (#68) 2025-07-08 11:30:34 +02:00
Václav Volhejn
cafac63222
Run pre-commit correctly in CI (#66)
* fix and break

* Remove intentional error
2025-07-08 10:11:52 +02:00
19 changed files with 1510 additions and 29 deletions

83
.github/ISSUE_TEMPLATE/bug.yml vendored Normal file

@ -0,0 +1,83 @@
name: Bug Report
description: You found a bug.
labels: ["bug", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: dropdown
id: backend
attributes:
label: Backend impacted
description: Which backend is concerned with your bug report?
options:
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: dropdown
id: os
attributes:
label: Operating system
description: What is your operating system?
options:
- Linux
- Mac OS X
- Windows (unsupported)
default: 0
validations:
required: true
- type: dropdown
id: hardware
attributes:
label: Hardware
description: What hardware are you using?
options:
- CPU
- GPU with CUDA
- Metal with MLX
default: 0
validations:
required: true
- type: textarea
id: description
attributes:
label: Description
description: Provide a detailed description of your bug.
placeholder:
value:
validations:
required: true
- type: textarea
id: more_info
attributes:
label: Extra information
description: Please provide any other relevant information, such as log extracts, code etc.
placeholder:
value:
validations:
required: true
- type: textarea
id: env
attributes:
label: Environment
description: Please provide details about the environment you are running on.
placeholder:
value: |
Fill in the following information on your system.
- Operating system version:
If the backend impacted is PyTorch:
- Python version:
- PyTorch version:
- CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
- GPU model and memory:
If the backend is MLX:
- Mac model:
validations:
required: true

40
.github/ISSUE_TEMPLATE/question.yml vendored Normal file

@ -0,0 +1,40 @@
name: Question
description: You have a question about the codebase, the paper, or the implementation.
labels: ["question", "triage"]
body:
- type: markdown
attributes:
value: |
Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
- type: checkboxes
id: terms
attributes:
label: Due diligence
description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
options:
- label: I have done my due diligence in trying to find the answer myself.
required: true
- type: dropdown
id: backend
attributes:
label: Topic
description: What is your question about?
options:
- The paper
- The PyTorch implementation
- The MLX implementation
- The Rust implementation
- Other / All
default: 0
validations:
required: true
- type: textarea
id: question
attributes:
label: Question
description: What is your question?
placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
value:
validations:
required: true

View File

@ -13,5 +13,5 @@ jobs:
- uses: actions/checkout@v2
- uses: ./.github/actions/moshi_build
- run: |
. env/bin/activate
bash .git/hooks/pre-commit
source env/bin/activate
pre-commit run --all-files

1
.gitignore vendored

@ -192,3 +192,4 @@ cython_debug/
# refer to https://docs.cursor.com/context/ignore-files
.cursorignore
.cursorindexingignore
out*.wav

56
FAQ.md Normal file

@ -0,0 +1,56 @@
# FAQ
Here are the answers to a number of frequently asked questions.
### Torch compilation issues
With some PyTorch/triton versions, one might encounter compilation errors
like the following:
```
Traceback (most recent call last):
...
File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
"launch_enter_hook": binary.__class__.launch_enter_hook,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
```
If that's the case, you can disable torch compilation by setting the following
environment variable.
```bash
export NO_TORCH_COMPILE=1
```
### Issues installing the sentencepiece dependency
On some Linux distributions (e.g. Arch) or on macOS, the local version of cmake can
be too recent for the sentencepiece dependency.
```
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
Compatibility with CMake < 3.5 has been removed from CMake.
```
You can either downgrade your cmake version (e.g. 3.31.0 works on Arch), or try
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.
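For example, assuming the sentencepiece build is triggered by a pip install (the exact command depends on your setup), the variable can be set just for that build:
```bash
# Workaround for the CMake < 3.5 compatibility error; the install command below
# is only an example of whatever triggers the sentencepiece build on your machine.
export CMAKE_POLICY_VERSION_MINIMUM=3.5
pip install moshi
```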
If you run into some errors when compiling the sentencepiece rust bindings,
these could also be due to gcc being too recent, e.g. gcc 15. You can get
around this by using gcc-13, e.g. by setting the following after installing
the proper gcc packages.
```bash
export CMAKE_C_COMPILER=/usr/bin/gcc-13
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
CC=gcc-13 CXX=g++-13 cargo build --release
```
Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
[issue](https://github.com/google/sentencepiece/issues/1108).
### Will you release training code?
Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
yet, but it should be a good starting point.

View File

@ -237,6 +237,14 @@ echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
```
The `tts_pytorch.py` script waits for all the text to be available before
starting the audio generation. A fully streaming implementation is available in
the `tts_pytorch_streaming.py` script, which can be used as follows:
```bash
echo "Hey, how are you?" | python scripts/tts_pytorch_streaming.py audio_output.wav
```
This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the installation step
and just prefix the command above with `uvx --with moshi`.
@ -297,6 +305,10 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
and just prefix the command above with `uvx --with moshi-mlx`.
</details>
## FAQ
Check out the [Frequently Asked Questions](FAQ.md) section before opening an issue.
## License
The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.

298
app/api_server.py Normal file

@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import os
import io
import time
import asyncio
import subprocess
from pathlib import Path
from typing import Optional, Literal
import logging
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model variables - loaded once at startup
tts_model = None
device = None
sample_rate = None
class SpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
app = FastAPI(
title="OpenAI-Compatible TTS API (Cached)",
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
version="2.0.0"
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)
def load_tts_model():
"""Load TTS model once at startup and keep in memory"""
global tts_model, device, sample_rate
if tts_model is not None:
logger.info("TTS model already loaded")
return
try:
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
# Import Kyutai TTS modules
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load the TTS model
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info,
n_q=32,
temp=0.6,
device=device
)
# Get sample rate
sample_rate = tts_model.mimi.sample_rate
logger.info(f"✅ TTS model loaded successfully!")
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
logger.info(f" Device: {device}")
logger.info(f" Sample Rate: {sample_rate}")
except Exception as e:
logger.error(f"❌ Failed to load TTS model: {e}")
raise
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
"""Generate audio using cached TTS model"""
global tts_model, device, sample_rate
if tts_model is None:
raise HTTPException(status_code=500, detail="TTS model not loaded")
try:
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
# Prepare the script (text input)
entries = tts_model.prepare_script([text], padding_between=1)
# Voice mapping for OpenAI compatibility
voice_mapping = {
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
}
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
try:
voice_path = tts_model.get_voice_path(selected_voice)
except:
# Fallback to default if voice not found
voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
# Prepare condition attributes
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
# Generate audio
pcms = []
def on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
all_entries = [entries]
all_condition_attributes = [condition_attributes]
with tts_model.mimi.streaming(len(all_entries)):
result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
# Concatenate all audio frames
if pcms:
import numpy as np
audio = np.concatenate(pcms, axis=-1)
# Apply speed adjustment if needed
if speed != 1.0:
# Simple speed adjustment by resampling
from scipy import signal
audio_length = len(audio)
new_length = int(audio_length / speed)
audio = signal.resample(audio, new_length)
# Convert to bytes
audio_bytes = io.BytesIO()
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
audio_bytes.seek(0)
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
return audio_bytes.read()
else:
raise Exception("No audio frames generated")
except Exception as e:
logger.error(f"❌ TTS generation error: {e}")
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
"""Convert WAV audio to requested format using ffmpeg"""
try:
if output_format == "wav":
return audio_wav_bytes
# Use ffmpeg to convert
cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
result = subprocess.run(
cmd,
input=audio_wav_bytes,
capture_output=True,
check=True
)
return result.stdout
except subprocess.CalledProcessError as e:
logger.error(f"Audio conversion failed: {e}")
raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
"""
OpenAI-compatible audio speech endpoint
Uses cached TTS model for fast generation
"""
try:
start_time = time.time()
# Generate audio with cached model
audio_wav_bytes = generate_audio_fast(
text=request.input,
voice=request.voice,
speed=request.speed
)
# Convert to requested format
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
generation_time = time.time() - start_time
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
# Set appropriate content type
content_types = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm"
}
return Response(
content=audio_data,
media_type=content_types.get(request.response_format, "audio/wav"),
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Generation-Time": str(generation_time)
}
)
except Exception as e:
logger.error(f"Speech generation failed: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List available models (OpenAI-compatible)"""
return {
"object": "list",
"data": [
{
"id": "tts-1",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1",
"parent": None
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1-hd",
"parent": None
}
]
}
@app.get("/health")
async def health_check():
"""Health check endpoint with model status"""
model_loaded = tts_model is not None
return {
"status": "healthy" if model_loaded else "loading",
"model_loaded": model_loaded,
"cuda_available": torch.cuda.is_available(),
"device": str(device) if device else None,
"service": "kyutai-tts-openai-compatible-cached"
}
@app.get("/reload-model")
async def reload_model():
"""Reload the TTS model (admin endpoint)"""
global tts_model
try:
tts_model = None
load_tts_model()
return {"status": "success", "message": "Model reloaded successfully"}
except Exception as e:
return {"status": "error", "message": str(e)}
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
logger.info("🚀 Starting TTS API server with model caching...")
load_tts_model()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
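A request against the `/v1/audio/speech` endpoint defined above could look like the following sketch; the host is an assumption and the port comes from the `uvicorn.run` call:
```bash
# Hypothetical client call; the JSON fields follow the SpeechRequest model above.
curl -X POST http://localhost:8000/v1/audio/speech \
  -H "Content-Type: application/json" \
  -d '{"model": "tts-1", "input": "Hey, how are you?", "voice": "alloy", "response_format": "wav"}' \
  --output speech.wav
```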

67
app/dependency_check.py Normal file

@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""
import sys
def check_dependencies():
print("🔍 Checking Kyutai TTS Dependencies")
print("=" * 40)
dependencies = [
"torch",
"numpy",
"einops",
"transformers",
"accelerate",
"soundfile",
"librosa",
"huggingface_hub",
"moshi",
"sphn"
]
missing = []
installed = []
for dep in dependencies:
try:
__import__(dep)
installed.append(dep)
print(f"{dep}")
except ImportError as e:
missing.append((dep, str(e)))
print(f"{dep}: {e}")
print(f"\n📊 Summary:")
print(f"✓ Installed: {len(installed)}")
print(f"✗ Missing: {len(missing)}")
if missing:
print(f"\n🔧 To fix missing dependencies:")
for dep, error in missing:
print(f"pip install {dep}")
print(f"\n🧪 Testing Kyutai TTS imports:")
try:
from moshi.models.loaders import CheckpointInfo
print("✓ CheckpointInfo import successful")
except Exception as e:
print(f"✗ CheckpointInfo import failed: {e}")
try:
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
print("✓ TTSModel imports successful")
except Exception as e:
print(f"✗ TTSModel imports failed: {e}")
return len(missing) == 0
if __name__ == "__main__":
success = check_dependencies()
if success:
print("\n🎉 All dependencies are installed correctly!")
else:
print("\n❌ Some dependencies are missing. Please install them first.")
sys.exit(1)
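To run the check, something like the following should work, assuming the file lives at `app/dependency_check.py` as in the commits above:
```bash
# Exits with a non-zero status if any dependency is missing.
python app/dependency_check.py
```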

59
app/scripts/tts_runner.py Normal file

@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import sys
import os
import argparse
import torch
from pathlib import Path
def main():
parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
parser.add_argument('input_file', help='Input text file or "-" for stdin')
parser.add_argument('output_file', help='Output audio file')
parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
args = parser.parse_args()
print(f"Using device: {args.device}")
print(f"CUDA available: {torch.cuda.is_available()}")
# Handle stdin input
if args.input_file == '-':
# Read from stdin and create temporary file
text = sys.stdin.read().strip()
temp_file = '/tmp/temp_input.txt'
with open(temp_file, 'w') as f:
f.write(text)
input_file = temp_file
else:
input_file = args.input_file
# Check if the original TTS script exists
tts_script = Path('/app/scripts/tts_pytorch.py')
if tts_script.exists():
print("Using original TTS script from Kyutai repository")
import subprocess
cmd = ['python', str(tts_script), input_file, args.output_file]
subprocess.run(cmd, check=True)
else:
print("Using moshi package for TTS generation")
import subprocess
cmd = [
'python', '-m', 'moshi.run_inference',
'--hf-repo', args.model,
input_file,
args.output_file
]
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
print(f"Error: {result.stderr}")
sys.exit(1)
print(f"Audio generated: {args.output_file}")
if __name__ == '__main__':
main()
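Possible invocations of this runner, assuming it is executed from the repository root (paths are illustrative):
```bash
# Read the text from stdin and write a WAV file.
echo "Hello from Kyutai TTS" | python app/scripts/tts_runner.py - output.wav

# Read the text from a file instead.
python app/scripts/tts_runner.py input.txt output.wav
```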

78
install.sh Normal file

@ -0,0 +1,78 @@
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0
# Install system dependencies
apt-get update && apt-get install -y \
wget \
curl \
git \
build-essential \
libsndfile1 \
ffmpeg \
sox \
alsa-utils \
pulseaudio \
&& rm -rf /var/lib/apt/lists/*
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Create virtual environment
apt-get install -y python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate
# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip
# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install core dependencies
pip install --no-cache-dir \
numpy \
scipy \
librosa \
soundfile \
huggingface_hub \
einops \
transformers \
accelerate
# Install API dependencies
pip install --no-cache-dir \
fastapi \
uvicorn[standard] \
python-multipart \
pydantic
# Install moshi package with all dependencies (following Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"
# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output
# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo
# Copy the TTS script from the repository (fall back to a local checkout)
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"
# Start TTS-Server
python /app/api_server.py

View File

@ -0,0 +1,100 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sentencepiece",
# "sounddevice",
# "sphn",
# ]
# ///
import argparse
import json
import mlx.core as mx
import mlx.nn as nn
import sentencepiece
import sphn
from huggingface_hub import hf_hub_download
from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("in_file", help="The file to transcribe.")
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
audio, _ = sphn.read(args.in_file, sample_rate=24000)
if args.hf_repo is None:
if args.vad:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
else:
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
mimi_weights = hf_hub_download(args.hf_repo, lm_config["mimi_name"])
moshi_name = lm_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_hub_download(args.hf_repo, moshi_name)
text_tokenizer = hf_hub_download(args.hf_repo, lm_config["tokenizer_name"])
lm_config = models.LmConfig.from_config_dict(lm_config)
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
if moshi_weights.endswith(".q4.safetensors"):
nn.quantize(model, bits=4, group_size=32)
elif moshi_weights.endswith(".q8.safetensors"):
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
if args.hf_repo.endswith("-candle"):
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {text_tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(text_tokenizer) # type: ignore
print(f"loading the audio tokenizer {mimi_weights}")
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(32))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
print("warming up the model")
model.warmup()
gen = models.LmGen(
model=model,
max_steps=args.max_steps,
text_sampler=utils.Sampler(top_k=25, temp=0),
audio_sampler=utils.Sampler(top_k=250, temp=0.8),
check=False,
)
print(f"starting inference {audio.shape}")
audio = mx.concat([mx.array(audio), mx.zeros((1, 48000))], axis=-1)
last_print_was_vad = False
for start_idx in range(0, audio.shape[-1] // 1920 * 1920, 1920):
block = audio[:, None, start_idx : start_idx + 1920]
other_audio_tokens = audio_tokenizer.encode_step(block).transpose(0, 2, 1)
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
_text = None
if text_token not in (0, 3):
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
print()

View File

@ -4,7 +4,7 @@
# "julius",
# "librosa",
# "soundfile",
# "moshi",
# "moshi==0.2.11",
# ]
# ///
@ -20,8 +20,8 @@ import math
import julius
import moshi.models
import sphn
import time
import torch
import tqdm
@dataclasses.dataclass
@ -128,6 +128,9 @@ def tokens_to_timestamped_text(
def main(args):
if args.vad and args.hf_repo is None:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
args.hf_repo,
moshi_weights=args.moshi_weight,
@ -171,14 +174,35 @@ def main(args):
itertools.repeat(silence_chunk, n_suffix_chunks),
)
start_time = time.time()
nchunks = 0
last_print_was_vad = False
with mimi.streaming(1), lm_gen.streaming(1):
for audio_chunk in tqdm.tqdm(chunks):
for audio_chunk in chunks:
nchunks += 1
audio_tokens = mimi.encode(audio_chunk)
text_tokens = lm_gen.step(audio_tokens)
if text_tokens is not None:
text_tokens_accum.append(text_tokens)
if args.vad:
text_tokens, vad_heads = lm_gen.step_with_extra_heads(audio_tokens)
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].cpu().item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_tokens = lm_gen.step(audio_tokens)
text_token = text_tokens[0, 0, 0].cpu().item()
if text_token not in (0, 3):
_text = tokenizer.id_to_piece(text_tokens[0, 0, 0].cpu().item()) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False
text_tokens_accum.append(text_tokens)
utterance_tokens = torch.concat(text_tokens_accum, dim=-1)
dt = time.time() - start_time
print(
f"\nprocessed {nchunks} chunks in {dt:.2f} seconds, steps per second: {nchunks / dt:.2f}"
)
timed_text = tokens_to_timestamped_text(
utterance_tokens,
tokenizer,
@ -209,6 +233,9 @@ if __name__ == "__main__":
parser.add_argument(
"--config-path", type=str, help="Path to a local config file.", default=None
)
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
parser.add_argument(
"--device",
type=str,

View File

@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx",
# "moshi_mlx==0.2.12",
# "numpy",
# "rustymimi",
# "sentencepiece",
@ -25,9 +25,17 @@ from moshi_mlx import models, utils
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--max-steps", default=4096)
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
parser.add_argument("--hf-repo")
parser.add_argument(
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
)
args = parser.parse_args()
if args.hf_repo is None:
if args.vad:
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
else:
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
lm_config = hf_hub_download(args.hf_repo, "config.json")
with open(lm_config, "r") as fobj:
lm_config = json.load(fobj)
@ -45,7 +53,10 @@ if __name__ == "__main__":
nn.quantize(model, bits=8, group_size=64)
print(f"loading model weights from {moshi_weights}")
model.load_weights(moshi_weights, strict=True)
if args.hf_repo.endswith("-candle"):
model.load_pytorch_weights(moshi_weights, lm_config, strict=True)
else:
model.load_weights(moshi_weights, strict=True)
print(f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(tokenizer) # type: ignore
@ -71,6 +82,7 @@ if __name__ == "__main__":
block_queue.put(indata.copy())
print("recording audio from microphone, speak to get your words transcribed")
last_print_was_vad = False
with sd.InputStream(
channels=1,
dtype="float32",
@ -85,7 +97,15 @@ if __name__ == "__main__":
other_audio_tokens = mx.array(other_audio_tokens).transpose(0, 2, 1)[
:, :, :other_codebooks
]
text_token = gen.step(other_audio_tokens[0])
if args.vad:
text_token, vad_heads = gen.step_with_extra_heads(other_audio_tokens[0])
if vad_heads:
pr_vad = vad_heads[2][0, 0, 0].item()
if pr_vad > 0.5 and not last_print_was_vad:
print(" [end of turn detected]")
last_print_was_vad = True
else:
text_token = gen.step(other_audio_tokens[0])
text_token = text_token[0].item()
audio_tokens = gen.last_audio_tokens()
_text = None
@ -93,3 +113,4 @@ if __name__ == "__main__":
_text = text_tokenizer.id_to_piece(text_token) # type: ignore
_text = _text.replace("▁", " ")
print(_text, end="", flush=True)
last_print_was_vad = False

View File

@ -2,7 +2,7 @@
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.9",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
@ -36,7 +36,7 @@ def log(level: str, msg: str):
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
description="Run Kyutai TTS using the MLX implementation"
)
parser.add_argument("inp", type=str, help="Input file, use - for stdin")
parser.add_argument(
@ -76,6 +76,9 @@ def main():
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
# The following line gets around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)

View File

@ -0,0 +1,317 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "huggingface_hub",
# "moshi_mlx==0.2.12",
# "numpy",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import json
import queue
import sys
import time
import mlx.core as mx
import mlx.nn as nn
import numpy as np
import sentencepiece
import sounddevice as sd
import sphn
import typing as tp
from moshi_mlx import models
from moshi_mlx.models.generate import LmGen
from moshi_mlx.client_utils import make_log
from moshi_mlx.modules.conditioner import (
ConditionAttributes,
ConditionTensor,
dropout_all_conditions,
)
from moshi_mlx.utils.sampling import Sampler
from moshi_mlx.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
script_to_entries,
)
from moshi_mlx.utils.loaders import hf_get
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[mx.array], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
self.ct = None
self.cross_attention_src = None
for _attr in attributes:
for _key, _value in _attr.text.items():
_ct = tts_model.lm.condition_provider.condition_tensor(_key, _value)
if self.ct is None:
self.ct = _ct
else:
self.ct = ConditionTensor(self.ct.tensor + _ct.tensor)
for _key, _value in _attr.tensor.items():
_conditioner = tts_model.lm.condition_provider.conditioners[_key]
_ca_src = _conditioner.condition(_value)
if self.cross_attention_src is None:
self.cross_attention_src = _ca_src
else:
raise ValueError("multiple cross-attention conditioners")
def _on_audio_hook(audio_tokens):
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[0]):
delay = delays[q]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = mx.array(out_tokens, dtype=mx.int64)
self.lm_gen = LmGen(
tts_model.lm,
max_steps=tts_model.max_gen_length,
text_sampler=Sampler(temp=tts_model.temp),
audio_sampler=Sampler(temp=tts_model.temp),
cfg_coef=tts_model.cfg_coef,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
# TODO(laurent):
# cfg_is_masked_until=cfg_is_masked_until,
# cfg_is_no_text=cfg_is_no_text,
)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = (
mx.ones((1, missing), dtype=mx.int64)
* self.tts_model.machine.token_ids.zero
)
self.lm_gen.step(
input_tokens, ct=self.ct, cross_attention_src=self.cross_attention_src
)
frame = self.lm_gen.last_audio_tokens()
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
def log(level: str, msg: str):
print(make_log(level, msg))
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the MLX implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice", default="expresso/ex03-ex01_happy_001_channel1_334s.wav"
)
parser.add_argument(
"--quantize",
type=int,
help="The quantization to be applied, e.g. 8 for 8 bits.",
)
args = parser.parse_args()
mx.random.seed(299792458)
log("info", "retrieving checkpoints")
raw_config = hf_get("config.json", args.hf_repo)
with open(hf_get(raw_config), "r") as fobj:
raw_config = json.load(fobj)
mimi_weights = hf_get(raw_config["mimi_name"], args.hf_repo)
moshi_name = raw_config.get("moshi_name", "model.safetensors")
moshi_weights = hf_get(moshi_name, args.hf_repo)
tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
lm_config = models.LmConfig.from_config_dict(raw_config)
# There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
# The following line gets around it for now.
lm_config.transformer.max_seq_len = lm_config.transformer.context
model = models.Lm(lm_config)
model.set_dtype(mx.bfloat16)
log("info", f"loading model weights from {moshi_weights}")
model.load_pytorch_weights(str(moshi_weights), lm_config, strict=True)
if args.quantize is not None:
log("info", f"quantizing model to {args.quantize} bits")
nn.quantize(model.depformer, bits=args.quantize)
for layer in model.transformer.layers:
nn.quantize(layer.self_attn, bits=args.quantize)
nn.quantize(layer.gating, bits=args.quantize)
log("info", f"loading the text tokenizer from {tokenizer}")
text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer)) # type: ignore
log("info", f"loading the audio tokenizer {mimi_weights}")
generated_codebooks = lm_config.generated_codebooks
audio_tokenizer = models.mimi.Mimi(models.mimi_202407(generated_codebooks))
audio_tokenizer.load_pytorch_weights(str(mimi_weights), strict=True)
cfg_coef_conditioning = None
tts_model = TTSModel(
model,
audio_tokenizer,
text_tokenizer,
voice_repo=args.voice_repo,
temp=0.6,
cfg_coef=1,
max_padding=8,
initial_padding=2,
final_padding=2,
padding_bonus=0,
raw_config=raw_config,
)
if tts_model.valid_cfg_conditionings:
# Model was trained with CFG distillation.
cfg_coef_conditioning = tts_model.cfg_coef
tts_model.cfg_coef = 1.0
mimi = tts_model.mimi
log("info", "reading input from stdin")
if tts_model.multi_speaker:
voices = [tts_model.get_voice_path(args.voice)]
else:
voices = []
all_attributes = [
tts_model.make_condition_attributes(voices, cfg_coef_conditioning)
]
wav_frames = queue.Queue()
def _on_frame(frame):
if (frame == -1).any():
return
_pcm = tts_model.mimi.decode_step(frame[:, :, None])
_pcm = np.array(mx.clip(_pcm[0, 0], -1, 1))
wav_frames.put_nowait(_pcm)
gen = TTSGen(tts_model, all_attributes, on_frame=_on_frame)
def run():
log("info", "starting the inference loop")
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
if args.out == "-":
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = wav_frames.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
with sd.OutputStream(
samplerate=mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
):
run()
while True:
if wav_frames.qsize() == 0:
break
time.sleep(1)
else:
run()
frames = []
while True:
try:
frames.append(wav_frames.get_nowait())
except queue.Empty:
break
wav = np.concat(frames, -1)
sphn.write_wav(args.out, wav, mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -1,7 +1,7 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi==0.2.8",
# "moshi==0.2.11",
# "torch",
# "sphn",
# "sounddevice",
@ -68,13 +68,17 @@ def main():
# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
entries = tts_model.prepare_script([text], padding_between=1)
voice_path = tts_model.get_voice_path(args.voice)
if args.voice.endswith(".safetensors"):
voice_path = args.voice
else:
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
_frames_cnt = 0
if args.out == "-":
# Stream the audio to the speakers using sounddevice.
@ -83,9 +87,12 @@ def main():
pcms = queue.Queue()
def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
def audio_callback(outdata, _a, _b, _c):
try:
@ -101,14 +108,25 @@ def main():
callback=audio_callback,
):
with tts_model.mimi.streaming(1):
tts_model.generate([entries], [condition_attributes], on_frame=_on_frame)
tts_model.generate(
[entries], [condition_attributes], on_frame=_on_frame
)
time.sleep(3)
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else:
result = tts_model.generate([entries], [condition_attributes])
def _on_frame(frame):
nonlocal _frames_cnt
if (frame != -1).all():
_frames_cnt += 1
print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
result = tts_model.generate(
[entries], [condition_attributes], on_frame=_on_frame
)
with tts_model.mimi.streaming(1), torch.no_grad():
pcms = []
for frame in result.frames[tts_model.delay_steps :]:

View File

@ -0,0 +1,261 @@
# /// script
# requires-python = ">=3.12"
# dependencies = [
# "moshi==0.2.11",
# "torch",
# "sphn",
# "sounddevice",
# ]
# ///
import argparse
from dataclasses import dataclass
import sys
import numpy as np
import queue
import sphn
import time
import torch
import typing as tp
from moshi.models.loaders import CheckpointInfo
from moshi.conditioners import dropout_all_conditions
from moshi.models.lm import LMGen
from moshi.models.tts import (
Entry,
DEFAULT_DSM_TTS_REPO,
DEFAULT_DSM_TTS_VOICE_REPO,
TTSModel,
ConditionAttributes,
script_to_entries,
)
def prepare_script(model: TTSModel, script: str, first_turn: bool) -> list[Entry]:
multi_speaker = first_turn and model.multi_speaker
return script_to_entries(
model.tokenizer,
model.machine.token_ids,
model.mimi.frame_rate,
[script],
multi_speaker=multi_speaker,
padding_between=1,
)
def _make_null(
all_attributes: tp.Sequence[ConditionAttributes],
) -> list[ConditionAttributes]:
# When using CFG, returns the null conditions.
return dropout_all_conditions(all_attributes)
@dataclass
class TTSGen:
tts_model: TTSModel
attributes: tp.Sequence[ConditionAttributes]
on_frame: tp.Optional[tp.Callable[[torch.Tensor], None]] = None
def __post_init__(self):
tts_model = self.tts_model
attributes = self.attributes
self.offset = 0
self.state = self.tts_model.machine.new_state([])
if tts_model.cfg_coef != 1.0:
if tts_model.valid_cfg_conditionings:
raise ValueError(
"This model does not support direct CFG, but was trained with "
"CFG distillation. Pass instead `cfg_coef` to `make_condition_attributes`."
)
nulled = _make_null(attributes)
attributes = list(attributes) + nulled
assert tts_model.lm.condition_provider is not None
prepared = tts_model.lm.condition_provider.prepare(attributes)
condition_tensors = tts_model.lm.condition_provider(prepared)
def _on_text_logits_hook(text_logits):
if tts_model.padding_bonus:
text_logits[..., tts_model.machine.token_ids.pad] += (
tts_model.padding_bonus
)
return text_logits
def _on_audio_hook(audio_tokens):
audio_offset = tts_model.lm.audio_offset
delays = tts_model.lm.delays
for q in range(audio_tokens.shape[1]):
delay = delays[q + audio_offset]
if self.offset < delay + tts_model.delay_steps:
audio_tokens[:, q] = tts_model.machine.token_ids.zero
def _on_text_hook(text_tokens):
tokens = text_tokens.tolist()
out_tokens = []
for token in tokens:
out_token, _ = tts_model.machine.process(self.offset, self.state, token)
out_tokens.append(out_token)
text_tokens[:] = torch.tensor(
out_tokens, dtype=torch.long, device=text_tokens.device
)
tts_model.lm.dep_q = tts_model.n_q
self.lm_gen = LMGen(
tts_model.lm,
temp=tts_model.temp,
temp_text=tts_model.temp,
cfg_coef=tts_model.cfg_coef,
condition_tensors=condition_tensors,
on_text_logits_hook=_on_text_logits_hook,
on_text_hook=_on_text_hook,
on_audio_hook=_on_audio_hook,
cfg_is_masked_until=None,
cfg_is_no_text=True,
)
self.lm_gen.streaming_forever(1)
def process_last(self):
while len(self.state.entries) > 0 or self.state.end_step is not None:
self._step()
additional_steps = (
self.tts_model.delay_steps + max(self.tts_model.lm.delays) + 8
)
for _ in range(additional_steps):
self._step()
def process(self):
while len(self.state.entries) > self.tts_model.machine.second_stream_ahead:
self._step()
def _step(self):
missing = self.tts_model.lm.n_q - self.tts_model.lm.dep_q
input_tokens = torch.full(
(1, missing, 1),
self.tts_model.machine.token_ids.zero,
dtype=torch.long,
device=self.tts_model.lm.device,
)
frame = self.lm_gen.step(input_tokens)
self.offset += 1
if frame is not None:
if self.on_frame is not None:
self.on_frame(frame)
def append_entry(self, entry):
self.state.entries.append(entry)
@torch.no_grad()
def main():
parser = argparse.ArgumentParser(
description="Run Kyutai TTS using the PyTorch implementation"
)
parser.add_argument(
"out", type=str, help="Output file to generate, use - for playing the audio"
)
parser.add_argument(
"--hf-repo",
type=str,
default=DEFAULT_DSM_TTS_REPO,
help="HF repo in which to look for the pretrained models.",
)
parser.add_argument(
"--voice-repo",
default=DEFAULT_DSM_TTS_VOICE_REPO,
help="HF repo in which to look for pre-computed voice embeddings.",
)
parser.add_argument(
"--voice",
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args()
print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=args.device
)
if args.voice.endswith(".safetensors"):
voice_path = args.voice
else:
voice_path = tts_model.get_voice_path(args.voice)
# CFG coef goes here because the model was trained with CFG distillation,
# so it's not _actually_ doing CFG at inference time.
# Also, if you are generating a dialog, you should have two voices in the list.
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
if args.out == "-":
# Stream the audio to the speakers using sounddevice.
import sounddevice as sd
pcms = queue.Queue()
def _on_frame(frame):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
def audio_callback(outdata, _a, _b, _c):
try:
pcm_data = pcms.get(block=False)
outdata[:, 0] = pcm_data
except queue.Empty:
outdata[:] = 0
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with sd.OutputStream(
samplerate=tts_model.mimi.sample_rate,
blocksize=1920,
channels=1,
callback=audio_callback,
) and tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
while True:
if pcms.qsize() == 0:
break
time.sleep(1)
else:
pcms = []
def _on_frame(frame: torch.Tensor):
if (frame != -1).all():
pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
pcms.append(np.clip(pcm[0, 0], -1, 1))
gen = TTSGen(tts_model, [condition_attributes], on_frame=_on_frame)
with tts_model.mimi.streaming(1):
first_turn = True
for line in sys.stdin:
entries = prepare_script(tts_model, line.strip(), first_turn=first_turn)
first_turn = False
for entry in entries:
gen.append_entry(entry)
gen.process()
gen.process_last()
pcm = np.concatenate(pcms, axis=-1)
sphn.write_wav(args.out, pcm, tts_model.mimi.sample_rate)
if __name__ == "__main__":
main()

View File

@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
print(f"Saved audio to {out}")
async def read_lines_from_stdin():
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
loop = asyncio.get_running_loop()
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
while True:
line = await reader.readline()
if not line:
break
yield line.decode().rstrip()
async def read_lines_from_file(path: str):
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
with open(path, "r", encoding="utf-8") as f:
for line in f:
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
await asyncio.to_thread(producer)
while True:
line = await queue.get()
if line is None:
break
yield line
async def get_lines(source: str):
if source == "-":
async for line in read_lines_from_stdin():
yield line
else:
async for line in read_lines_from_file(source):
yield line
async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
@ -113,25 +152,26 @@ async def websocket_client():
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri)
# TODO: stream the text instead of sending it all at once
if args.inp == "-":
if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket:
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
await websocket.send(msgpack.packb({"type": "Eos"}))
print("connected")
async def send_loop():
print("go send")
async for line in get_lines(args.inp):
for word in line.split():
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
await asyncio.gather(receive_task, output_audio_task)
send_task = asyncio.create_task(send_loop())
await asyncio.gather(receive_task, output_audio_task, send_task)
if __name__ == "__main__":

View File

@ -9,9 +9,9 @@
"source": [
"# Fast install, might break in the future.\n",
"!pip install 'sphn<0.2'\n",
"!pip install --no-deps \"moshi==0.2.8\"\n",
"!pip install --no-deps \"moshi==0.2.11\"\n",
"# Slow install (will download torch and cuda), but future proof.\n",
"# !pip install \"moshi==0.2.8\""
"# !pip install \"moshi==0.2.11\""
]
},
{