#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
|
|
|
|
import os
|
|
import io
|
|
import time
|
|
import asyncio
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import Optional, Literal
|
|
import logging
|
|
|
|
import torch
|
|
import soundfile as sf
|
|
from fastapi import FastAPI, HTTPException
|
|
from fastapi.responses import Response
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field
|
|
import uvicorn
|
|
|
|
# Logging: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Process-wide model state. All three names start as None; load_tts_model()
# assigns them, and the request handlers read them afterwards.
tts_model = device = sample_rate = None
|
|
|
|
class SpeechRequest(BaseModel):
    # Request body for the speech endpoint. Documented with comments rather
    # than a class docstring so the generated schema stays unchanged.

    # Model id; only the two listed identifiers pass validation.
    model: Literal["tts-1", "tts-1-hd"] = Field(
        default="tts-1",
        description="TTS model to use",
    )
    # Required text to synthesize, between 1 and 4096 characters.
    input: str = Field(
        default=...,
        min_length=1,
        max_length=4096,
        description="Text to generate audio for",
    )
    # Voice name; restricted to the six listed names.
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field(
        default="alloy",
        description="Voice to use",
    )
    # Encoding of the returned audio.
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field(
        default="mp3",
        description="Audio format",
    )
    # Speed factor, bounded to [0.25, 4.0]; the Optional type also admits None.
    speed: Optional[float] = Field(
        default=1.0,
        ge=0.25,
        le=4.0,
        description="Speed of generated audio",
    )
|
|
|
|
# The FastAPI application object that every route below is registered on.
app = FastAPI(
    title="OpenAI-Compatible TTS API (Cached)",
    description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
    version="2.0.0",
)

# CORS policy: any origin, method and header is accepted, credentials allowed.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended beyond local/trusted deployments.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
|
|
|
|
# Directory reserved for API output files; created at import time.
OUTPUT_DIR = Path("/app/api_output")
# parents=True: without it, mkdir raises FileNotFoundError (and the whole
# module fails to import) whenever the parent directory /app does not exist.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
|
|
|
def load_tts_model():
    """Populate the module-level ``tts_model``, ``device`` and ``sample_rate``.

    When a model is already present the call only logs that fact and returns.
    Any exception raised while importing or loading is logged and re-raised.
    """
    global tts_model, device, sample_rate

    # Guard clause: keep the model that is already in memory.
    if tts_model is not None:
        logger.info("TTS model already loaded")
        return

    try:
        logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")

        # Imported inside the try so an import failure is logged like any
        # other loading error.
        from moshi.models.loaders import CheckpointInfo
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

        # Use CUDA when torch reports it as available, otherwise the CPU.
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        device = torch.device(backend)
        logger.info("Using device: %s", device)

        checkpoint = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
        tts_model = TTSModel.from_checkpoint_info(
            checkpoint, n_q=32, temp=0.6, device=device
        )

        # The sample rate is read from the model's mimi component.
        sample_rate = tts_model.mimi.sample_rate

        logger.info("✅ TTS model loaded successfully!")
        logger.info(" Model: %s", DEFAULT_DSM_TTS_REPO)
        logger.info(" Device: %s", device)
        logger.info(" Sample Rate: %s", sample_rate)
    except Exception as load_error:
        logger.error("❌ Failed to load TTS model: %s", load_error)
        raise
|
|
|
|
# Voice sample file used for each OpenAI-style voice name.
_VOICE_MAPPING = {
    "alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
    "echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
    "fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
    "onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
    "nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
    "shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav",
}
# Voice file used when the requested name is unknown or cannot be resolved.
_DEFAULT_VOICE_FILE = _VOICE_MAPPING["alloy"]


def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
    """Synthesize ``text`` with the cached TTS model and return WAV bytes.

    Args:
        text: Text to synthesize.
        voice: OpenAI-style voice name; names missing from the voice mapping
            use the "alloy" entry.
        speed: Speed factor. Values other than 1.0 are applied by resampling
            the waveform to ``len(audio) / speed`` samples; ``None`` is
            treated as 1.0.

    Returns:
        The generated audio encoded as an in-memory WAV file.

    Raises:
        HTTPException: status 500 when the model is not loaded or when any
            step of the generation fails.
    """
    if tts_model is None:
        raise HTTPException(status_code=500, detail="TTS model not loaded")

    try:
        import numpy as np

        preview = text[:50] + ("..." if len(text) > 50 else "")
        logger.info("🎵 Generating audio for: '%s'", preview)

        # Prepare the script (text input).
        entries = tts_model.prepare_script([text], padding_between=1)

        selected_voice = _VOICE_MAPPING.get(voice, _DEFAULT_VOICE_FILE)
        try:
            voice_path = tts_model.get_voice_path(selected_voice)
        except Exception as voice_error:
            # Best-effort fallback to the default voice. The original bare
            # ``except:`` also swallowed KeyboardInterrupt/SystemExit and
            # left no trace of the failure in the logs.
            logger.warning(
                "Voice %r unavailable (%s); falling back to default voice",
                selected_voice,
                voice_error,
            )
            voice_path = tts_model.get_voice_path(_DEFAULT_VOICE_FILE)

        condition_attributes = tts_model.make_condition_attributes(
            [voice_path], cfg_coef=2.0
        )

        pcms = []

        def on_frame(frame):
            # Frames that still contain -1 entries are skipped; every other
            # frame is decoded and clipped to the [-1, 1] sample range.
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(np.clip(pcm[0, 0], -1, 1))

        all_entries = [entries]
        all_condition_attributes = [condition_attributes]

        # Audio is collected through on_frame; the return value of generate()
        # is not needed.
        with tts_model.mimi.streaming(len(all_entries)):
            tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)

        if not pcms:
            raise RuntimeError("No audio frames generated")

        audio = np.concatenate(pcms, axis=-1)

        # Simple speed adjustment by resampling to a different sample count
        # while keeping the sample rate.
        if speed is not None and speed != 1.0:
            from scipy import signal

            # Never ask for an empty signal, however short the clip is.
            new_length = max(1, int(len(audio) / speed))
            audio = signal.resample(audio, new_length)

        wav_buffer = io.BytesIO()
        sf.write(wav_buffer, audio, samplerate=sample_rate, format="WAV")

        logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
        return wav_buffer.getvalue()

    except Exception as e:
        # logger.exception keeps the traceback, which the 500 response hides.
        logger.exception("❌ TTS generation error: %s", e)
        raise HTTPException(
            status_code=500, detail=f"Audio generation failed: {str(e)}"
        ) from e
|
|
|
|
# ffmpeg output arguments for each API response format. "aac" and "pcm" are
# API-level names, not ffmpeg output format names: AAC is written through the
# "adts" muxer and raw PCM as signed 16-bit little-endian samples ("s16le").
_FFMPEG_OUTPUT_ARGS = {
    "mp3": ["-f", "mp3"],
    "opus": ["-f", "opus"],
    "aac": ["-f", "adts"],
    "flac": ["-f", "flac"],
    "pcm": ["-f", "s16le", "-acodec", "pcm_s16le"],
}


def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
    """Convert WAV audio to the requested format using ffmpeg.

    Args:
        audio_wav_bytes: A complete WAV file as bytes.
        output_format: Target format name. "wav" returns the input unchanged
            without invoking ffmpeg; names without a dedicated mapping are
            passed to ffmpeg's ``-f`` option as-is.

    Returns:
        The converted audio bytes read from ffmpeg's standard output.

    Raises:
        HTTPException: status 500 when ffmpeg exits with a non-zero status or
            the ffmpeg executable cannot be found.
    """
    if output_format == "wav":
        return audio_wav_bytes

    output_args = _FFMPEG_OUTPUT_ARGS.get(output_format, ["-f", output_format])
    # Audio is streamed through stdin/stdout, so no temporary files are used.
    cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", *output_args, "pipe:1"]

    try:
        result = subprocess.run(
            cmd,
            input=audio_wav_bytes,
            capture_output=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        # ffmpeg explains failures on stderr; log it, since the exception
        # text only carries the command line and exit status.
        stderr_text = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        logger.error(f"Audio conversion failed: {e}")
        if stderr_text:
            logger.error("ffmpeg stderr: %s", stderr_text)
        raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}") from e
    except FileNotFoundError as e:
        # Raised by subprocess.run when the ffmpeg binary is not on PATH.
        logger.error(f"Audio conversion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}") from e

    return result.stdout
|
|
|
|
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
    """
    OpenAI-compatible audio speech endpoint
    Uses cached TTS model for fast generation
    """
    # The docstring above is kept verbatim because FastAPI publishes it as
    # the operation description.
    #
    # Flow: synthesize WAV bytes, convert them to request.response_format,
    # then return the bytes with a matching media type and an
    # X-Generation-Time header holding the elapsed seconds.
    # Any failure is answered with an HTTP 500.
    try:
        start_time = time.time()

        # Generate audio with cached model
        audio_wav_bytes = generate_audio_fast(
            text=request.input,
            voice=request.voice,
            # speed is Optional on the request model: an explicit null must
            # behave like the default instead of reaching arithmetic as None.
            speed=1.0 if request.speed is None else request.speed,
        )

        # Convert to requested format
        audio_data = convert_audio_format(audio_wav_bytes, request.response_format)

        generation_time = time.time() - start_time
        logger.info(f"⚡ Total generation time: {generation_time:.2f}s")

        # Set appropriate content type
        content_types = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm",
        }

        return Response(
            content=audio_data,
            media_type=content_types.get(request.response_format, "audio/wav"),
            headers={
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Generation-Time": str(generation_time),
            },
        )

    except HTTPException:
        # The helpers already raise HTTPException with a specific detail;
        # letting the generic handler below re-wrap it would replace that
        # detail with str(exception).
        raise
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)"""
    # The two entries differ only in their id (repeated under "root"), so
    # they are generated from the id list rather than written out twice.
    model_ids = ("tts-1", "tts-1-hd")
    entries = [
        {
            "id": model_id,
            "object": "model",
            "created": 1677610602,
            "owned_by": "kyutai",
            "permission": [],
            "root": model_id,
            "parent": None,
        }
        for model_id in model_ids
    ]
    return {"object": "list", "data": entries}
|
|
|
|
@app.get("/health")
async def health_check():
    """Health check endpoint with model status"""
    # "loading" is reported for as long as the global model is still None.
    if tts_model is None:
        model_loaded = False
        status = "loading"
    else:
        model_loaded = True
        status = "healthy"

    # The device is reported as a string, or None while it is unset.
    device_name = str(device) if device else None

    report = {
        "status": status,
        "model_loaded": model_loaded,
        "cuda_available": torch.cuda.is_available(),
        "device": device_name,
        "service": "kyutai-tts-openai-compatible-cached",
    }
    return report
|
|
|
|
@app.get("/reload-model")
async def reload_model():
    """Reload the TTS model (admin endpoint)"""
    global tts_model

    # The global must be cleared first: load_tts_model() returns early when
    # a model is already present.
    tts_model = None
    try:
        load_tts_model()
    except Exception as reload_error:
        # Failures are reported in the response body, not as an HTTP error.
        return {"status": "error", "message": str(reload_error)}
    return {"status": "success", "message": "Model reloaded successfully"}
|
|
|
|
@app.on_event("startup")
async def startup_event():
    """Log the server start, then load the TTS model."""
    logger.info("🚀 Starting TTS API server with model caching...")
    # An exception raised by load_tts_model() is not caught here, so it
    # propagates out of the startup hook.
    load_tts_model()
|
|
|
|
# Script entry point: serve the app with uvicorn on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=8000,
    )