app/api_server.py added

commit f23d1be027 (parent cf97f8d863)
tipi, 2025-08-12 17:56:32 +00:00

app/api_server.py (new file, 298 lines)

@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import io
import time
import subprocess
from pathlib import Path
from typing import Optional, Literal
import logging
import numpy as np
import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global model variables - loaded once at startup
tts_model = None
device = None
sample_rate = None
class SpeechRequest(BaseModel):
model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")
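# Illustrative request body for POST /v1/audio/speech (example values, not
# additional server defaults beyond those declared above):
#   {"model": "tts-1", "input": "Hello world", "voice": "alloy",
#    "response_format": "mp3", "speed": 1.0}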
app = FastAPI(
title="OpenAI-Compatible TTS API (Cached)",
description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
version="2.0.0"
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
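# Note: this directory is created up front but is not referenced again in this
# file; generated audio is assembled and returned entirely in memory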
def load_tts_model():
"""Load TTS model once at startup and keep in memory"""
global tts_model, device, sample_rate
if tts_model is not None:
logger.info("TTS model already loaded")
return
try:
logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
# Import Kyutai TTS modules
from moshi.models.loaders import CheckpointInfo
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
# Load the TTS model
checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
tts_model = TTSModel.from_checkpoint_info(
checkpoint_info,
n_q=32,
temp=0.6,
device=device
)
# Get sample rate
sample_rate = tts_model.mimi.sample_rate
logger.info(f"✅ TTS model loaded successfully!")
logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}")
logger.info(f" Device: {device}")
logger.info(f" Sample Rate: {sample_rate}")
except Exception as e:
logger.error(f"❌ Failed to load TTS model: {e}")
raise
def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
"""Generate audio using cached TTS model"""
global tts_model, device, sample_rate
if tts_model is None:
        raise HTTPException(status_code=503, detail="TTS model not loaded")
try:
logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")
# Prepare the script (text input)
entries = tts_model.prepare_script([text], padding_between=1)
# Voice mapping for OpenAI compatibility
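        # (the Expresso clips below are assumed to resolve via
        # tts_model.get_voice_path; any other available voice path would work)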
voice_mapping = {
"alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
"echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
"fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
"onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
"nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
"shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
}
selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
        try:
            voice_path = tts_model.get_voice_path(selected_voice)
        except Exception:
            # Fall back to the default voice if the requested one is unavailable
            voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")
# Prepare condition attributes
condition_attributes = tts_model.make_condition_attributes(
[voice_path], cfg_coef=2.0
)
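        # cfg_coef sets the classifier-free guidance strength for the voice
        # conditioning; 2.0 follows the value used in Kyutai's example scripts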
# Generate audio
        pcms = []
        def on_frame(frame):
            # Skip frames that still contain the -1 placeholder token; decode
            # all but the first token stream to PCM and clamp to [-1, 1]
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())
all_entries = [entries]
all_condition_attributes = [condition_attributes]
with tts_model.mimi.streaming(len(all_entries)):
            tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)
        # Concatenate all audio frames
        if pcms:
            audio = np.concatenate(pcms, axis=-1)
            # Apply speed adjustment if needed
            if speed != 1.0:
                # Naive speed change via resampling (note: this also shifts pitch)
                from scipy import signal
                new_length = int(len(audio) / speed)
                audio = signal.resample(audio, new_length)
# Convert to bytes
audio_bytes = io.BytesIO()
sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
audio_bytes.seek(0)
logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
return audio_bytes.read()
        else:
            raise RuntimeError("No audio frames generated")
except Exception as e:
logger.error(f"❌ TTS generation error: {e}")
raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")
def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
    """Convert WAV audio to the requested format using ffmpeg"""
    if output_format == "wav":
        return audio_wav_bytes
    # ffmpeg muxer names do not always match the API format names: raw AAC
    # uses the "adts" muxer, and raw PCM needs an explicit sample format
    if output_format == "pcm":
        cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0",
               "-f", "s16le", "-acodec", "pcm_s16le", "pipe:1"]
    else:
        muxer = "adts" if output_format == "aac" else output_format
        cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", muxer, "pipe:1"]
    try:
        result = subprocess.run(
            cmd,
            input=audio_wav_bytes,
            capture_output=True,
            check=True
        )
        return result.stdout
    except subprocess.CalledProcessError as e:
        stderr = e.stderr.decode(errors="ignore") if e.stderr else str(e)
        logger.error(f"Audio conversion failed: {stderr}")
        raise HTTPException(status_code=500, detail="Audio conversion failed")
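# For reference, the conversion above corresponds to a shell pipeline such as
# (illustrative, with mp3 as the target format):
#   ffmpeg -f wav -i pipe:0 -f mp3 pipe:1 < speech.wav > speech.mp3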
@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
"""
OpenAI-compatible audio speech endpoint
Uses cached TTS model for fast generation
"""
try:
start_time = time.time()
# Generate audio with cached model
audio_wav_bytes = generate_audio_fast(
text=request.input,
voice=request.voice,
speed=request.speed
)
# Convert to requested format
audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
generation_time = time.time() - start_time
logger.info(f"⚡ Total generation time: {generation_time:.2f}s")
# Set appropriate content type
content_types = {
"mp3": "audio/mpeg",
"opus": "audio/opus",
"aac": "audio/aac",
"flac": "audio/flac",
"wav": "audio/wav",
"pcm": "audio/pcm"
}
return Response(
content=audio_data,
media_type=content_types.get(request.response_format, "audio/wav"),
headers={
"Content-Disposition": f"attachment; filename=speech.{request.response_format}",
"X-Generation-Time": str(generation_time)
}
)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/v1/models")
async def list_models():
"""List available models (OpenAI-compatible)"""
return {
"object": "list",
"data": [
{
"id": "tts-1",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1",
"parent": None
},
{
"id": "tts-1-hd",
"object": "model",
"created": 1677610602,
"owned_by": "kyutai",
"permission": [],
"root": "tts-1-hd",
"parent": None
}
]
}
@app.get("/health")
async def health_check():
"""Health check endpoint with model status"""
model_loaded = tts_model is not None
return {
"status": "healthy" if model_loaded else "loading",
"model_loaded": model_loaded,
"cuda_available": torch.cuda.is_available(),
"device": str(device) if device else None,
"service": "kyutai-tts-openai-compatible-cached"
}
@app.get("/reload-model")
async def reload_model():
"""Reload the TTS model (admin endpoint)"""
global tts_model
try:
tts_model = None
load_tts_model()
return {"status": "success", "message": "Model reloaded successfully"}
except Exception as e:
return {"status": "error", "message": str(e)}
@app.on_event("startup")
async def startup_event():
"""Load model on startup"""
logger.info("🚀 Starting TTS API server with model caching...")
load_tts_model()
if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000)
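# Example usage once the server is running (a sketch; host/port follow the
# uvicorn.run() call above, and api_key is a placeholder since this server
# performs no authentication):
#
#   curl -X POST http://localhost:8000/v1/audio/speech \
#     -H "Content-Type: application/json" \
#     -d '{"model": "tts-1", "input": "Hello world", "voice": "alloy"}' \
#     --output speech.mp3
#
# or with the official openai Python client pointed at this server:
#
#   from openai import OpenAI
#   client = OpenAI(base_url="http://localhost:8000/v1", api_key="unused")
#   resp = client.audio.speech.create(model="tts-1", voice="alloy", input="Hello world")
#   resp.write_to_file("speech.mp3")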