Compare commits

...

6 Commits

Author          SHA1        Message                                      Date
                96042c9983  install.sh added                             2025-08-12 18:03:43 +00:00
                7e92c6f28a  app/scripts/tts_runner.py added              2025-08-12 17:59:20 +00:00
                8e51c6eab9  app/dependency_check.py updated              2025-08-12 17:58:17 +00:00
                34e58b8055  dependency_check.py added                    2025-08-12 17:57:33 +00:00
                f23d1be027  app/api_server.py added                      2025-08-12 17:56:32 +00:00
Laurent Mazare  cf97f8d863  Workaround for the mlx kv-cache bug. (#108)  2025-08-04 16:37:00 +02:00
6 changed files with 508 additions and 0 deletions

app/api_server.py (new file, +298 lines)

@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import io
import time
import logging
import subprocess
from pathlib import Path
from typing import Optional, Literal

import numpy as np
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state - loaded once at startup
tts_model = None
device = None
sample_rate = None


class SpeechRequest(BaseModel):
    model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
    input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
    speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")


app = FastAPI(
    title="OpenAI-Compatible TTS API (Cached)",
    description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
    version="2.0.0",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)


def load_tts_model():
    """Load the TTS model once at startup and keep it in memory."""
    global tts_model, device, sample_rate

    if tts_model is not None:
        logger.info("TTS model already loaded")
        return

    try:
        logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")

        # Import Kyutai TTS modules
        from moshi.models.loaders import CheckpointInfo
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Load the TTS model
        checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
        tts_model = TTSModel.from_checkpoint_info(
            checkpoint_info,
            n_q=32,
            temp=0.6,
            device=device,
        )

        # Get sample rate
        sample_rate = tts_model.mimi.sample_rate

        logger.info("✅ TTS model loaded successfully!")
        logger.info(f"   Model: {DEFAULT_DSM_TTS_REPO}")
        logger.info(f"   Device: {device}")
        logger.info(f"   Sample Rate: {sample_rate}")
    except Exception as e:
        logger.error(f"❌ Failed to load TTS model: {e}")
        raise


def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
    """Generate audio using the cached TTS model."""
    if tts_model is None:
        raise HTTPException(status_code=500, detail="TTS model not loaded")

    try:
        logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare the script (text input)
        entries = tts_model.prepare_script([text], padding_between=1)

        # Map OpenAI voice names onto bundled Expresso voices
        voice_mapping = {
            "alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
            "echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
            "fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
            "onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
            "nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
            "shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav",
        }
        selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
        try:
            voice_path = tts_model.get_voice_path(selected_voice)
        except Exception:
            # Fall back to the default voice if the requested one is not found
            voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")

        # Prepare condition attributes
        condition_attributes = tts_model.make_condition_attributes(
            [voice_path], cfg_coef=2.0
        )

        # Generate audio frame by frame
        pcms = []

        def on_frame(frame):
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())

        all_entries = [entries]
        all_condition_attributes = [condition_attributes]
        with tts_model.mimi.streaming(len(all_entries)):
            tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)

        if not pcms:
            raise Exception("No audio frames generated")

        # Concatenate all audio frames
        audio = np.concatenate(pcms, axis=-1)

        # Apply speed adjustment if needed (simple resampling)
        if speed != 1.0:
            from scipy import signal

            new_length = int(len(audio) / speed)
            audio = signal.resample(audio, new_length)

        # Convert to WAV bytes
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio, samplerate=sample_rate, format="WAV")
        audio_bytes.seek(0)

        logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
        return audio_bytes.read()
    except Exception as e:
        logger.error(f"❌ TTS generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")


def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
    """Convert WAV audio to the requested format using ffmpeg."""
    if output_format == "wav":
        return audio_wav_bytes

    try:
        # Map API format names to ffmpeg muxer names where they differ
        muxer = {"aac": "adts", "pcm": "s16le"}.get(output_format, output_format)
        cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", muxer, "pipe:1"]
        result = subprocess.run(
            cmd,
            input=audio_wav_bytes,
            capture_output=True,
            check=True,
        )
        return result.stdout
    except subprocess.CalledProcessError as e:
        logger.error(f"Audio conversion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")


@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
    """
    OpenAI-compatible audio speech endpoint.
    Uses the cached TTS model for fast generation.
    """
    try:
        start_time = time.time()

        # Generate audio with the cached model
        audio_wav_bytes = generate_audio_fast(
            text=request.input,
            voice=request.voice,
            speed=request.speed,
        )

        # Convert to the requested format
        audio_data = convert_audio_format(audio_wav_bytes, request.response_format)

        generation_time = time.time() - start_time
        logger.info(f"⚡ Total generation time: {generation_time:.2f}s")

        # Set the appropriate content type
        content_types = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm",
        }
        return Response(
            content=audio_data,
            media_type=content_types.get(request.response_format, "audio/wav"),
            headers={
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Generation-Time": str(generation_time),
            },
        )
    except HTTPException:
        # Preserve status codes raised by the helpers above
        raise
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)."""
    return {
        "object": "list",
        "data": [
            {
                "id": "tts-1",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1",
                "parent": None,
            },
            {
                "id": "tts-1-hd",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1-hd",
                "parent": None,
            },
        ],
    }


@app.get("/health")
async def health_check():
    """Health check endpoint with model status."""
    model_loaded = tts_model is not None
    return {
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "cuda_available": torch.cuda.is_available(),
        "device": str(device) if device else None,
        "service": "kyutai-tts-openai-compatible-cached",
    }


@app.get("/reload-model")
async def reload_model():
    """Reload the TTS model (admin endpoint)."""
    global tts_model
    try:
        tts_model = None
        load_tts_model()
        return {"status": "success", "message": "Model reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


@app.on_event("startup")
async def startup_event():
    """Load the model on startup."""
    logger.info("🚀 Starting TTS API server with model caching...")
    load_tts_model()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
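Note: a minimal client-side smoke test for this endpoint, assuming the server is reachable on localhost:8000 (the payload follows the SpeechRequest schema defined above):

# Smoke test for POST /v1/audio/speech (assumes default host/port).
import requests

resp = requests.post(
    "http://localhost:8000/v1/audio/speech",
    json={
        "model": "tts-1",
        "input": "Hello from Kyutai TTS!",
        "voice": "alloy",
        "response_format": "wav",
    },
    timeout=120,
)
resp.raise_for_status()
with open("speech.wav", "wb") as f:
    f.write(resp.content)
print("Generation time:", resp.headers.get("X-Generation-Time"))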

app/dependency_check.py (new file, +67 lines)

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""
import sys


def check_dependencies():
    print("🔍 Checking Kyutai TTS Dependencies")
    print("=" * 40)

    dependencies = [
        "torch",
        "numpy",
        "einops",
        "transformers",
        "accelerate",
        "soundfile",
        "librosa",
        "huggingface_hub",
        "moshi",
        "sphn",
    ]

    missing = []
    installed = []

    for dep in dependencies:
        try:
            __import__(dep)
            installed.append(dep)
            print(f"✓ {dep}")
        except ImportError as e:
            missing.append((dep, str(e)))
            print(f"✗ {dep}: {e}")

    print("\n📊 Summary:")
    print(f"✓ Installed: {len(installed)}")
    print(f"✗ Missing: {len(missing)}")

    if missing:
        print("\n🔧 To fix missing dependencies:")
        for dep, _ in missing:
            print(f"pip install {dep}")

    print("\n🧪 Testing Kyutai TTS imports:")
    try:
        from moshi.models.loaders import CheckpointInfo
        print("✓ CheckpointInfo import successful")
    except Exception as e:
        print(f"✗ CheckpointInfo import failed: {e}")
    try:
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
        print("✓ TTSModel imports successful")
    except Exception as e:
        print(f"✗ TTSModel imports failed: {e}")

    return len(missing) == 0


if __name__ == "__main__":
    success = check_dependencies()
    if success:
        print("\n🎉 All dependencies are installed correctly!")
    else:
        print("\n❌ Some dependencies are missing. Please install them first.")
        sys.exit(1)
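Note: since check_dependencies() returns a boolean, the checker can also gate server startup programmatically; a small sketch, assuming the file is importable as dependency_check from the /app working directory:

# Hypothetical startup gate built on the checker above.
from dependency_check import check_dependencies

if not check_dependencies():
    raise SystemExit("Install the missing packages before starting the API server.")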

app/scripts/tts_runner.py (new file, +59 lines)

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import sys
import argparse
import subprocess
from pathlib import Path

import torch


def main():
    parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
    parser.add_argument('input_file', help='Input text file or "-" for stdin')
    parser.add_argument('output_file', help='Output audio file')
    parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')
    args = parser.parse_args()

    print(f"Using device: {args.device}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    # Handle stdin input
    if args.input_file == '-':
        # Read from stdin and write to a temporary file
        text = sys.stdin.read().strip()
        temp_file = '/tmp/temp_input.txt'
        with open(temp_file, 'w') as f:
            f.write(text)
        input_file = temp_file
    else:
        input_file = args.input_file

    # Prefer the original TTS script from the Kyutai repository if present
    tts_script = Path('/app/scripts/tts_pytorch.py')
    if tts_script.exists():
        print("Using original TTS script from Kyutai repository")
        cmd = ['python', str(tts_script), input_file, args.output_file]
        subprocess.run(cmd, check=True)
    else:
        print("Using moshi package for TTS generation")
        cmd = [
            'python', '-m', 'moshi.run_inference',
            '--hf-repo', args.model,
            input_file,
            args.output_file,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"Error: {result.stderr}")
            sys.exit(1)

    print(f"Audio generated: {args.output_file}")


if __name__ == '__main__':
    main()
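Note: the runner accepts either a file path or "-" for stdin; a minimal sketch of driving it from Python via stdin, assuming the container paths used above:

# Feed text on stdin ("-") and write the result to /app/output.
import subprocess

subprocess.run(
    ["python", "/app/scripts/tts_runner.py", "-", "/app/output/hello.wav"],
    input=b"Hello from the dockerized runner.",
    check=True,
)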

install.sh (new file, +78 lines)

@@ -0,0 +1,78 @@
#!/bin/bash
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0

# Install system dependencies
apt-get update && apt-get install -y \
    wget \
    curl \
    git \
    build-essential \
    libsndfile1 \
    ffmpeg \
    sox \
    alsa-utils \
    pulseaudio \
    && rm -rf /var/lib/apt/lists/*

# Create and activate a virtual environment
apt-get install -y python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate

# Upgrade pip first (for better caching)
pip install --no-cache-dir --upgrade pip

# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install core dependencies
pip install --no-cache-dir \
    numpy \
    scipy \
    librosa \
    soundfile \
    huggingface_hub \
    einops \
    transformers \
    accelerate

# Install API dependencies
pip install --no-cache-dir \
    fastapi \
    'uvicorn[standard]' \
    python-multipart \
    pydantic

# Install the moshi package with pinned dependencies (following the Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"

# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output

# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo

# Copy the TTS script from the repository
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"

# Start the TTS server
python /app/api_server.py
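Note: install.sh ends by launching the API server in the foreground, so a readiness probe against the /health endpoint defined in api_server.py can confirm the model finished loading; a minimal sketch, assuming the default localhost:8000 binding:

# Poll /health until the cached TTS model reports as loaded.
import time
import requests

for _ in range(30):
    try:
        status = requests.get("http://localhost:8000/health", timeout=5).json()
        if status.get("model_loaded"):
            print("Server ready:", status)
            break
    except requests.exceptions.ConnectionError:
        pass  # server not up yet
    time.sleep(10)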

@@ -76,6 +76,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)
@@ -205,6 +205,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)