From f23d1be027ae25866dc254366a9b87eb38e4cf0c Mon Sep 17 00:00:00 2001 From: tipi Date: Tue, 12 Aug 2025 17:56:32 +0000 Subject: [PATCH] =?UTF-8?q?app/api=5Fserver.py=20hinzugef=C3=BCgt?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/api_server.py | 298 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 298 insertions(+) create mode 100644 app/api_server.py diff --git a/app/api_server.py b/app/api_server.py new file mode 100644 index 0000000..7f3010d --- /dev/null +++ b/app/api_server.py @@ -0,0 +1,298 @@ +#!/usr/bin/env python3 +""" +OpenAI-Compatible Kyutai TTS API Server with Model Caching +Improved version that loads the model once and keeps it in memory +""" + +import os +import io +import time +import asyncio +import subprocess +from pathlib import Path +from typing import Optional, Literal +import logging + +import torch +import soundfile as sf +from fastapi import FastAPI, HTTPException +from fastapi.responses import Response +from fastapi.middleware.cors import CORSMiddleware +from pydantic import BaseModel, Field +import uvicorn + +# Configure logging +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Global model variables - loaded once at startup +tts_model = None +device = None +sample_rate = None + +class SpeechRequest(BaseModel): + model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use") + input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for") + voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use") + response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format") + speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio") + +app = FastAPI( + title="OpenAI-Compatible TTS API (Cached)", + description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching", + version="2.0.0" +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) + +OUTPUT_DIR = Path("/app/api_output") +OUTPUT_DIR.mkdir(exist_ok=True) + +def load_tts_model(): + """Load TTS model once at startup and keep in memory""" + global tts_model, device, sample_rate + + if tts_model is not None: + logger.info("TTS model already loaded") + return + + try: + logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...") + + # Import Kyutai TTS modules + from moshi.models.loaders import CheckpointInfo + from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel + + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Load the TTS model + checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO) + + tts_model = TTSModel.from_checkpoint_info( + checkpoint_info, + n_q=32, + temp=0.6, + device=device + ) + + # Get sample rate + sample_rate = tts_model.mimi.sample_rate + + logger.info(f"✅ TTS model loaded successfully!") + logger.info(f" Model: {DEFAULT_DSM_TTS_REPO}") + logger.info(f" Device: {device}") + logger.info(f" Sample Rate: {sample_rate}") + + except Exception as e: + logger.error(f"❌ Failed to load TTS model: {e}") + raise + +def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes: + """Generate audio using cached TTS model""" + global tts_model, device, sample_rate + + if tts_model is None: + raise HTTPException(status_code=500, detail="TTS model not loaded") + + try: + logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'") + + # Prepare the script (text input) + entries = tts_model.prepare_script([text], padding_between=1) + + # Voice mapping for OpenAI compatibility + voice_mapping = { + "alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav", + "echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav", + "fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav", + "onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav", + "nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav", + "shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav" + } + + selected_voice = voice_mapping.get(voice, voice_mapping["alloy"]) + + try: + voice_path = tts_model.get_voice_path(selected_voice) + except: + # Fallback to default if voice not found + voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav") + + # Prepare condition attributes + condition_attributes = tts_model.make_condition_attributes( + [voice_path], cfg_coef=2.0 + ) + + # Generate audio + pcms = [] + + def on_frame(frame): + if (frame != -1).all(): + pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy() + pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy()) + + all_entries = [entries] + all_condition_attributes = [condition_attributes] + + with tts_model.mimi.streaming(len(all_entries)): + result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame) + + # Concatenate all audio frames + if pcms: + import numpy as np + audio = np.concatenate(pcms, axis=-1) + + # Apply speed adjustment if needed + if speed != 1.0: + # Simple speed adjustment by resampling + from scipy import signal + audio_length = len(audio) + new_length = int(audio_length / speed) + audio = signal.resample(audio, new_length) + + # Convert to bytes + audio_bytes = io.BytesIO() + sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV') + audio_bytes.seek(0) + + logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)") + return audio_bytes.read() + else: + raise Exception("No audio frames generated") + + except Exception as e: + logger.error(f"❌ TTS generation error: {e}") + raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}") + +def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes: + """Convert WAV audio to requested format using ffmpeg""" + try: + if output_format == "wav": + return audio_wav_bytes + + # Use ffmpeg to convert + cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"] + + result = subprocess.run( + cmd, + input=audio_wav_bytes, + capture_output=True, + check=True + ) + return result.stdout + + except subprocess.CalledProcessError as e: + logger.error(f"Audio conversion failed: {e}") + raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}") + +@app.post("/v1/audio/speech") +async def create_speech(request: SpeechRequest): + """ + OpenAI-compatible audio speech endpoint + Uses cached TTS model for fast generation + """ + try: + start_time = time.time() + + # Generate audio with cached model + audio_wav_bytes = generate_audio_fast( + text=request.input, + voice=request.voice, + speed=request.speed + ) + + # Convert to requested format + audio_data = convert_audio_format(audio_wav_bytes, request.response_format) + + generation_time = time.time() - start_time + logger.info(f"⚡ Total generation time: {generation_time:.2f}s") + + # Set appropriate content type + content_types = { + "mp3": "audio/mpeg", + "opus": "audio/opus", + "aac": "audio/aac", + "flac": "audio/flac", + "wav": "audio/wav", + "pcm": "audio/pcm" + } + + return Response( + content=audio_data, + media_type=content_types.get(request.response_format, "audio/wav"), + headers={ + "Content-Disposition": f"attachment; filename=speech.{request.response_format}", + "X-Generation-Time": str(generation_time) + } + ) + + except Exception as e: + logger.error(f"Speech generation failed: {e}") + raise HTTPException(status_code=500, detail=str(e)) + +@app.get("/v1/models") +async def list_models(): + """List available models (OpenAI-compatible)""" + return { + "object": "list", + "data": [ + { + "id": "tts-1", + "object": "model", + "created": 1677610602, + "owned_by": "kyutai", + "permission": [], + "root": "tts-1", + "parent": None + }, + { + "id": "tts-1-hd", + "object": "model", + "created": 1677610602, + "owned_by": "kyutai", + "permission": [], + "root": "tts-1-hd", + "parent": None + } + ] + } + +@app.get("/health") +async def health_check(): + """Health check endpoint with model status""" + model_loaded = tts_model is not None + return { + "status": "healthy" if model_loaded else "loading", + "model_loaded": model_loaded, + "cuda_available": torch.cuda.is_available(), + "device": str(device) if device else None, + "service": "kyutai-tts-openai-compatible-cached" + } + +@app.get("/reload-model") +async def reload_model(): + """Reload the TTS model (admin endpoint)""" + global tts_model + try: + tts_model = None + load_tts_model() + return {"status": "success", "message": "Model reloaded successfully"} + except Exception as e: + return {"status": "error", "message": str(e)} + +@app.on_event("startup") +async def startup_event(): + """Load model on startup""" + logger.info("🚀 Starting TTS API server with model caching...") + load_tts_model() + +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=8000) \ No newline at end of file