Compare commits

...

10 Commits

Author SHA1 Message Date
96042c9983 install.sh added 2025-08-12 18:03:43 +00:00
7e92c6f28a app/scripts/tts_runner.py added 2025-08-12 17:59:20 +00:00
8e51c6eab9 app/dependency_check.py updated 2025-08-12 17:58:17 +00:00
34e58b8055 dependency_check.py added 2025-08-12 17:57:33 +00:00
f23d1be027 app/api_server.py added 2025-08-12 17:56:32 +00:00
Laurent Mazare
cf97f8d863
Workaround for the mlx kv-cache bug. (#108) 2025-08-04 16:37:00 +02:00
Laurent Mazare
09468c239a
Print the duration of the audio generated so far. (#107) 2025-08-04 09:24:31 +02:00
Laurent Mazare
07729ed47e
Use the proper repos when vad is on. (#103) 2025-08-01 15:55:49 +02:00
Laurent Mazare
af2283de3f
Use a streaming input in the rust example. (#102)
* Use a streaming input in the rust example.

* Formatting.

* Another formatting tweak.
2025-07-31 17:41:57 +02:00
Laurent Mazare
7dc926d50c
Allow for using local voices in the pytorch examples. (#100) 2025-07-31 12:48:05 +02:00
12 changed files with 594 additions and 14 deletions

app/api_server.py (new file, 298 lines)

@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""
import io
import time
import subprocess
import logging
from pathlib import Path
from typing import Optional, Literal

import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state - loaded once at startup
tts_model = None
device = None
sample_rate = None


class SpeechRequest(BaseModel):
    model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
    input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
    speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")


app = FastAPI(
    title="OpenAI-Compatible TTS API (Cached)",
    description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)


def load_tts_model():
    """Load the TTS model once at startup and keep it in memory."""
    global tts_model, device, sample_rate
    if tts_model is not None:
        logger.info("TTS model already loaded")
        return
    try:
        logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")
        # Import Kyutai TTS modules
        from moshi.models.loaders import CheckpointInfo
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

        # Select device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Load the TTS model
        checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
        tts_model = TTSModel.from_checkpoint_info(
            checkpoint_info,
            n_q=32,
            temp=0.6,
            device=device
        )
        sample_rate = tts_model.mimi.sample_rate
        logger.info("✅ TTS model loaded successfully!")
        logger.info(f"   Model: {DEFAULT_DSM_TTS_REPO}")
        logger.info(f"   Device: {device}")
        logger.info(f"   Sample rate: {sample_rate}")
    except Exception as e:
        logger.error(f"❌ Failed to load TTS model: {e}")
        raise


def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
    """Generate audio using the cached TTS model."""
    if tts_model is None:
        raise HTTPException(status_code=500, detail="TTS model not loaded")
    try:
        logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare the script (text input)
        entries = tts_model.prepare_script([text], padding_between=1)

        # Map OpenAI voice names onto bundled Expresso voices
        voice_mapping = {
            "alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
            "echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
            "fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
            "onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
            "nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
            "shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav",
        }
        selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])
        try:
            voice_path = tts_model.get_voice_path(selected_voice)
        except Exception:
            # Fall back to the default voice if the requested one is unavailable
            voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")

        # Prepare condition attributes
        condition_attributes = tts_model.make_condition_attributes(
            [voice_path], cfg_coef=2.0
        )

        # Generate audio, collecting decoded PCM frames as they arrive
        pcms = []

        def on_frame(frame):
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())

        all_entries = [entries]
        all_condition_attributes = [condition_attributes]
        with tts_model.mimi.streaming(len(all_entries)):
            tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)

        # Concatenate all audio frames
        if not pcms:
            raise RuntimeError("No audio frames generated")
        import numpy as np
        audio = np.concatenate(pcms, axis=-1)

        # Apply speed adjustment if needed (simple resampling)
        if speed != 1.0:
            from scipy import signal
            new_length = int(len(audio) / speed)
            audio = signal.resample(audio, new_length)

        # Encode as WAV bytes
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio, samplerate=sample_rate, format="WAV")
        audio_bytes.seek(0)
        logger.info(f"✅ Audio generated successfully ({len(audio) / sample_rate:.2f}s)")
        return audio_bytes.read()
    except Exception as e:
        logger.error(f"❌ TTS generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Audio generation failed: {e}")


def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
    """Convert WAV audio to the requested format using ffmpeg."""
    if output_format == "wav":
        return audio_wav_bytes
    try:
        cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
        result = subprocess.run(
            cmd,
            input=audio_wav_bytes,
            capture_output=True,
            check=True
        )
        return result.stdout
    except subprocess.CalledProcessError as e:
        logger.error(f"Audio conversion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")


@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
    """
    OpenAI-compatible audio speech endpoint.
    Uses the cached TTS model for fast generation.
    """
    try:
        start_time = time.time()

        # Generate audio with the cached model
        audio_wav_bytes = generate_audio_fast(
            text=request.input,
            voice=request.voice,
            speed=request.speed
        )

        # Convert to the requested format
        audio_data = convert_audio_format(audio_wav_bytes, request.response_format)
        generation_time = time.time() - start_time
        logger.info(f"⚡ Total generation time: {generation_time:.2f}s")

        content_types = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm",
        }
        return Response(
            content=audio_data,
            media_type=content_types.get(request.response_format, "audio/wav"),
            headers={
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Generation-Time": str(generation_time)
            }
        )
    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)."""
    return {
        "object": "list",
        "data": [
            {
                "id": "tts-1",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1",
                "parent": None
            },
            {
                "id": "tts-1-hd",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1-hd",
                "parent": None
            }
        ]
    }


@app.get("/health")
async def health_check():
    """Health check endpoint with model status."""
    model_loaded = tts_model is not None
    return {
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "cuda_available": torch.cuda.is_available(),
        "device": str(device) if device else None,
        "service": "kyutai-tts-openai-compatible-cached"
    }


@app.get("/reload-model")
async def reload_model():
    """Reload the TTS model (admin endpoint)."""
    global tts_model
    try:
        tts_model = None
        load_tts_model()
        return {"status": "success", "message": "Model reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


@app.on_event("startup")
async def startup_event():
    """Load the model on startup."""
    logger.info("🚀 Starting TTS API server with model caching...")
    load_tts_model()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
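
For a quick smoke test of the endpoint above, a minimal client sketch follows; the host and port come from the uvicorn.run call, the sample text and output file name are illustrative only, and only the Python standard library is used:

# Hypothetical smoke test for the /v1/audio/speech endpoint defined above.
import json
import urllib.request

payload = {
    "model": "tts-1",
    "input": "Hello from Kyutai TTS!",
    "voice": "alloy",
    "response_format": "wav",  # wav skips the ffmpeg conversion path
}
req = urllib.request.Request(
    "http://localhost:8000/v1/audio/speech",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    audio = resp.read()
    gen_time = resp.headers.get("X-Generation-Time")
with open("speech.wav", "wb") as f:
    f.write(audio)
print(f"Wrote {len(audio)} bytes in {gen_time}s")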

app/dependency_check.py (new file, 67 lines)

@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""
import sys


def check_dependencies():
    print("🔍 Checking Kyutai TTS Dependencies")
    print("=" * 40)
    dependencies = [
        "torch",
        "numpy",
        "einops",
        "transformers",
        "accelerate",
        "soundfile",
        "librosa",
        "huggingface_hub",
        "moshi",
        "sphn",
    ]
    missing = []
    installed = []
    for dep in dependencies:
        try:
            __import__(dep)
            installed.append(dep)
            print(f"✓ {dep}")
        except ImportError as e:
            missing.append((dep, str(e)))
            print(f"✗ {dep}: {e}")

    print("\n📊 Summary:")
    print(f"✓ Installed: {len(installed)}")
    print(f"✗ Missing: {len(missing)}")
    if missing:
        print("\n🔧 To fix missing dependencies:")
        for dep, _ in missing:
            print(f"pip install {dep}")

    print("\n🧪 Testing Kyutai TTS imports:")
    try:
        from moshi.models.loaders import CheckpointInfo
        print("✓ CheckpointInfo import successful")
    except Exception as e:
        print(f"✗ CheckpointInfo import failed: {e}")
    try:
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
        print("✓ TTSModel imports successful")
    except Exception as e:
        print(f"✗ TTSModel imports failed: {e}")
    return len(missing) == 0


if __name__ == "__main__":
    success = check_dependencies()
    if success:
        print("\n🎉 All dependencies are installed correctly!")
    else:
        print("\n❌ Some dependencies are missing. Please install them first.")
        sys.exit(1)
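
A natural way to use this checker is as a pre-flight gate before the server starts. The sketch below is hypothetical wiring, assuming the /app layout used elsewhere in this changeset:

# Hypothetical pre-flight gate: run the dependency checker, then start the API server.
import subprocess
import sys

check = subprocess.run([sys.executable, "/app/dependency_check.py"])
if check.returncode != 0:
    sys.exit("Dependency check failed; install the missing packages first.")
subprocess.run([sys.executable, "/app/api_server.py"], check=True)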

app/scripts/tts_runner.py (new file, 59 lines)

@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import argparse
import subprocess
import sys
from pathlib import Path

import torch


def main():
    parser = argparse.ArgumentParser(description="Kyutai TTS PyTorch Runner")
    parser.add_argument("input_file", help='Input text file or "-" for stdin')
    parser.add_argument("output_file", help="Output audio file")
    parser.add_argument("--model", default="kyutai/tts-1.6b-en_fr", help="TTS model to use")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use")
    args = parser.parse_args()

    print(f"Using device: {args.device}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    # Handle stdin input by writing it to a temporary file
    if args.input_file == "-":
        text = sys.stdin.read().strip()
        temp_file = "/tmp/temp_input.txt"
        with open(temp_file, "w") as f:
            f.write(text)
        input_file = temp_file
    else:
        input_file = args.input_file

    # Prefer the original TTS script from the Kyutai repository if present
    tts_script = Path("/app/scripts/tts_pytorch.py")
    if tts_script.exists():
        print("Using original TTS script from Kyutai repository")
        cmd = ["python", str(tts_script), input_file, args.output_file]
        subprocess.run(cmd, check=True)
    else:
        print("Using moshi package for TTS generation")
        cmd = [
            "python", "-m", "moshi.run_inference",
            "--hf-repo", args.model,
            input_file,
            args.output_file,
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"Error: {result.stderr}")
            sys.exit(1)

    print(f"Audio generated: {args.output_file}")


if __name__ == "__main__":
    main()
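
A usage sketch for the runner, assuming the container layout above (/app/scripts, /app/output); the text and output path are illustrative:

# Hypothetical invocation: pipe text to the runner via stdin ("-").
import subprocess

text = "Hello from the dockerized Kyutai TTS runner."
subprocess.run(
    ["python", "/app/scripts/tts_runner.py", "-", "/app/output/hello.wav"],
    input=text, text=True, check=True,
)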

install.sh (new file, 78 lines)

@@ -0,0 +1,78 @@
#!/bin/bash
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0

# Install system dependencies
apt-get update && apt-get install -y \
    wget \
    curl \
    git \
    build-essential \
    libsndfile1 \
    ffmpeg \
    sox \
    alsa-utils \
    pulseaudio \
    && rm -rf /var/lib/apt/lists/*

# Create and activate a virtual environment
apt-get install -y python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate

# Upgrade pip first (for better caching)
pip install --no-cache-dir --upgrade pip

# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install core dependencies
pip install --no-cache-dir \
    numpy \
    scipy \
    librosa \
    soundfile \
    huggingface_hub \
    einops \
    transformers \
    accelerate

# Install API dependencies
pip install --no-cache-dir \
    fastapi \
    uvicorn[standard] \
    python-multipart \
    pydantic

# Install moshi package with all dependencies (following the Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"

# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output

# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo

# Copy the TTS script from the repository
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"

# Start the TTS server
python /app/api_server.py
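
Because api_server.py loads the model during startup, /health reports status "loading" for the first moments after launch. A minimal readiness probe, assuming the default localhost:8000 address from uvicorn.run in api_server.py:

# Hypothetical readiness probe: poll /health until model_loaded is true.
import json
import time
import urllib.request

def wait_until_ready(url="http://localhost:8000/health", timeout=300):
    deadline = time.time() + timeout
    while time.time() < deadline:
        try:
            with urllib.request.urlopen(url) as resp:
                status = json.load(resp)
            if status.get("model_loaded"):
                return status
        except OSError:
            pass  # server not accepting connections yet
        time.sleep(2)
    raise TimeoutError("TTS server did not become ready in time")

print(wait_until_ready())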

@@ -24,13 +24,18 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("in_file", help="The file to transcribe.")
     parser.add_argument("--max-steps", default=4096)
-    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    parser.add_argument("--hf-repo")
     parser.add_argument(
         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
     )
     args = parser.parse_args()

     audio, _ = sphn.read(args.in_file, sample_rate=24000)
+    if args.hf_repo is None:
+        if args.vad:
+            args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+        else:
+            args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
     lm_config = hf_hub_download(args.hf_repo, "config.json")
     with open(lm_config, "r") as fobj:
         lm_config = json.load(fobj)

@@ -128,6 +128,9 @@ def tokens_to_timestamped_text(

 def main(args):
+    if args.vad and args.hf_repo is None:
+        args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+
     info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
         args.hf_repo,
         moshi_weights=args.moshi_weight,

@@ -25,12 +25,17 @@ from moshi_mlx import models, utils

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--max-steps", default=4096)
-    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    parser.add_argument("--hf-repo")
     parser.add_argument(
         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
     )
     args = parser.parse_args()
+    if args.hf_repo is None:
+        if args.vad:
+            args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+        else:
+            args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
     lm_config = hf_hub_download(args.hf_repo, "config.json")
     with open(lm_config, "r") as fobj:
         lm_config = json.load(fobj)

@@ -76,6 +76,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)

@@ -205,6 +205,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)

@@ -68,6 +68,9 @@ def main():
     # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
     entries = tts_model.prepare_script([text], padding_between=1)
-    voice_path = tts_model.get_voice_path(args.voice)
+    if args.voice.endswith(".safetensors"):
+        voice_path = args.voice
+    else:
+        voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.

@@ -75,6 +78,7 @@ def main():
     condition_attributes = tts_model.make_condition_attributes(
         [voice_path], cfg_coef=2.0
     )
+    _frames_cnt = 0

     if args.out == "-":
         # Stream the audio to the speakers using sounddevice.

@@ -83,9 +87,12 @@ def main():
         pcms = queue.Queue()

         def _on_frame(frame):
+            nonlocal _frames_cnt
             if (frame != -1).all():
                 pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                 pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)

         def audio_callback(outdata, _a, _b, _c):
             try:

@@ -110,7 +117,16 @@ def main():
                 break
             time.sleep(1)
     else:
-        result = tts_model.generate([entries], [condition_attributes])
+        def _on_frame(frame):
+            nonlocal _frames_cnt
+            if (frame != -1).all():
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
+
+        result = tts_model.generate(
+            [entries], [condition_attributes], on_frame=_on_frame
+        )
         with tts_model.mimi.streaming(1), torch.no_grad():
             pcms = []
             for frame in result.frames[tts_model.delay_steps :]:

@@ -183,6 +183,9 @@ def main():
         checkpoint_info, n_q=32, temp=0.6, device=args.device
     )
-    voice_path = tts_model.get_voice_path(args.voice)
+    if args.voice.endswith(".safetensors"):
+        voice_path = args.voice
+    else:
+        voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.

@@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
     print(f"Saved audio to {out}")


+async def read_lines_from_stdin():
+    reader = asyncio.StreamReader()
+    protocol = asyncio.StreamReaderProtocol(reader)
+    loop = asyncio.get_running_loop()
+    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
+    while True:
+        line = await reader.readline()
+        if not line:
+            break
+        yield line.decode().rstrip()
+
+
+async def read_lines_from_file(path: str):
+    queue = asyncio.Queue()
+    loop = asyncio.get_running_loop()
+
+    def producer():
+        with open(path, "r", encoding="utf-8") as f:
+            for line in f:
+                asyncio.run_coroutine_threadsafe(queue.put(line), loop)
+        asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+
+    await asyncio.to_thread(producer)
+    while True:
+        line = await queue.get()
+        if line is None:
+            break
+        yield line
+
+
+async def get_lines(source: str):
+    if source == "-":
+        async for line in read_lines_from_stdin():
+            yield line
+    else:
+        async for line in read_lines_from_file(source):
+            yield line
+
+
 async def websocket_client():
     parser = argparse.ArgumentParser(description="Use the TTS streaming API")
     parser.add_argument("inp", type=str, help="Input file, use - for stdin.")

@@ -113,25 +152,26 @@ async def websocket_client():
     uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
     print(uri)

-    # TODO: stream the text instead of sending it all at once
     if args.inp == "-":
         if sys.stdin.isatty():  # Interactive
             print("Enter text to synthesize (Ctrl+D to end input):")
-        text_to_tts = sys.stdin.read().strip()
-    else:
-        with open(args.inp, "r") as fobj:
-            text_to_tts = fobj.read().strip()

     headers = {"kyutai-api-key": args.api_key}

     async with websockets.connect(uri, additional_headers=headers) as websocket:
-        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
-        await websocket.send(msgpack.packb({"type": "Eos"}))
+        print("connected")
+
+        async def send_loop():
+            print("go send")
+            async for line in get_lines(args.inp):
+                for word in line.split():
+                    await websocket.send(msgpack.packb({"type": "Text", "text": word}))
+            await websocket.send(msgpack.packb({"type": "Eos"}))
+
         output_queue = asyncio.Queue()
         receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
         output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
-        await asyncio.gather(receive_task, output_audio_task)
+        send_task = asyncio.create_task(send_loop())
+        await asyncio.gather(receive_task, output_audio_task, send_task)


 if __name__ == "__main__":
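
The new line sources can also be exercised on their own; a minimal sketch, assuming the three helpers above are defined in the same module:

# Hypothetical standalone check of the async line sources added in this diff.
import asyncio

async def demo():
    # "-" reads from stdin; a file path streams lines without blocking the event loop.
    async for line in get_lines("input.txt"):
        print(f"would send: {line.rstrip()}")

asyncio.run(demo())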