Compare commits

14 commits: bump-versi... → main
Commits:

- 96042c9983
- 7e92c6f28a
- 8e51c6eab9
- 34e58b8055
- f23d1be027
- cf97f8d863
- 09468c239a
- 07729ed47e
- af2283de3f
- 7dc926d50c
- ab8e8c59b7
- 5f17114618
- 405a82ba3f
- 3b584b100c
.github/ISSUE_TEMPLATE/bug.yml (vendored, new file)
@@ -0,0 +1,83 @@
name: Bug Report
description: You found a bug.
labels: ["bug", "triage"]
body:
  - type: markdown
    attributes:
      value: |
        Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
  - type: dropdown
    id: backend
    attributes:
      label: Backend impacted
      description: Which backend is concerned with your bug report?
      options:
        - The PyTorch implementation
        - The MLX implementation
        - The Rust implementation
        - Other / All
      default: 0
    validations:
      required: true
  - type: dropdown
    id: os
    attributes:
      label: Operating system
      description: What is your operating system?
      options:
        - Linux
        - Mac OS X
        - Windows (unsupported)
      default: 0
    validations:
      required: true
  - type: dropdown
    id: hardware
    attributes:
      label: Hardware
      description: What hardware are you using?
      options:
        - CPU
        - GPU with CUDA
        - Metal with MLX
      default: 0
    validations:
      required: true
  - type: textarea
    id: description
    attributes:
      label: Description
      description: Provide a detailed description of your bug.
      placeholder:
      value:
    validations:
      required: true
  - type: textarea
    id: more_info
    attributes:
      label: Extra information
      description: Please provide any other relevant information, such as log extracts, code etc.
      placeholder:
      value:
    validations:
      required: true
  - type: textarea
    id: env
    attributes:
      label: Environment
      description: Please provide any other relevant information, such as log extracts, code etc.
      placeholder:
      value: |
        Fill in the following information on your system.
        - Operating system version:

        If the backend impacted is PyTorch:
        - Python version:
        - PyTorch version:
        - CUDA version (run `python -c 'import torch; print(torch.version.cuda)'`):
        - GPU model and memory:

        If the backend is MLX:
        - Mac model:
    validations:
      required: true
.github/ISSUE_TEMPLATE/question.yml (vendored, new file)
@@ -0,0 +1,40 @@
name: Question
description: You have a question about the codebase, the paper, or the implementation.
labels: ["question", "triage"]
body:
  - type: markdown
    attributes:
      value: |
        Please first check the [FAQ](https://github.com/kyutai-labs/delayed-streams-modeling/blob/main/FAQ.md).
  - type: checkboxes
    id: terms
    attributes:
      label: Due diligence
      description: Have you searched the existing issues / FAQ / Google / asked ChatGPT?
      options:
        - label: I have done my due diligence in trying to find the answer myself.
          required: true
  - type: dropdown
    id: backend
    attributes:
      label: Topic
      description: What is your question about?
      options:
        - The paper
        - The PyTorch implementation
        - The MLX implementation
        - The Rust implementation
        - Other / All
      default: 0
    validations:
      required: true
  - type: textarea
    id: question
    attributes:
      label: Question
      description: What is your question?
      placeholder: Your question. Please make sure this is directly related to our codebase. We will not provide support for installing PyTorch, CUDA, Rust etc.
      value:
    validations:
      required: true
.gitignore (vendored, +1 line)
@@ -192,3 +192,4 @@ cython_debug/
 # refer to https://docs.cursor.com/context/ignore-files
 .cursorignore
 .cursorindexingignore
+out*.wav
FAQ.md (new file)
@@ -0,0 +1,56 @@
# FAQ

Here are the answers to a number of frequently asked questions.

### Torch compilation issues

With some PyTorch/triton versions, one might encounter compilation errors
like the following:
```
Traceback (most recent call last):
  ...
  File "site-packages/torch/_inductor/runtime/triton_heuristics.py", line 1153, in make_launcher
    "launch_enter_hook": binary.__class__.launch_enter_hook,
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch._inductor.exc.InductorError: AttributeError: type object 'CompiledKernel' has no attribute 'launch_enter_hook'
```

If that's the case, you can disable torch compilation by setting the following
environment variable.
```bash
export NO_TORCH_COMPILE=1
```

### Issues installing the sentencepiece dependency

On some Linux distributions (arch) or on macOS, the local version of cmake can
be too recent for the sentencepiece dependency.

```
CMake Error at CMakeLists.txt:15 (cmake_minimum_required):
  Compatibility with CMake < 3.5 has been removed from CMake.
```

You can either downgrade your cmake version, e.g. 3.31.0 works on arch, or try
setting `CMAKE_POLICY_VERSION_MINIMUM=3.5`.

If you run into some errors when compiling the sentencepiece rust bindings,
these could also be due to gcc being too recent, e.g. gcc 15. You can get
around this by using gcc-13, e.g. by setting the following after installing
the proper gcc packages.
```bash
export CMAKE_C_COMPILER=/usr/bin/gcc-13
export CMAKE_CXX_COMPILER=/usr/bin/g++-13
CC=gcc-13 CXX=g++-13 cargo build --release
```

Alternatively you can set `CXXFLAGS="-include cstdint"`, see this
[issue](https://github.com/google/sentencepiece/issues/1108).

### Will you release training code?

Some finetuning code can be found in the [kyutai-labs/moshi-finetune repo](https://github.com/kyutai-labs/moshi-finetune).
This code has not been adapted to the Speech-To-Text and Text-To-Speech models
yet, but it should be a good starting point.
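As an aside to the torch-compilation entry in the FAQ above, the same switch can be flipped from Python before any moshi code runs. A minimal sketch, assuming only that the `NO_TORCH_COMPILE` variable from the FAQ is read at runtime; the wrapper itself is hypothetical:

```python
# Hypothetical wrapper; the only fact taken from the FAQ above is the
# NO_TORCH_COMPILE environment variable itself.
import os

# Must be set before the code paths that trigger torch compilation run.
os.environ.setdefault("NO_TORCH_COMPILE", "1")

import runpy

# Forward to the real entry point, e.g. one of the repo's scripts.
runpy.run_path("scripts/tts_pytorch.py", run_name="__main__")
```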
@@ -305,6 +305,10 @@ If you have [uv](https://docs.astral.sh/uv/) installed, you can skip the install
 and just prefix the command above with `uvx --with moshi-mlx`.
 </details>

+## FAQ
+
+Check out the [Frequently Asked Questions](FAQ.md) section before opening an issue.
+
 ## License

 The present code is provided under the MIT license for the Python parts, and Apache license for the Rust backend.
app/api_server.py (new file)
@@ -0,0 +1,298 @@
#!/usr/bin/env python3
"""
OpenAI-Compatible Kyutai TTS API Server with Model Caching
Improved version that loads the model once and keeps it in memory
"""

import os
import io
import time
import asyncio
import subprocess
from pathlib import Path
from typing import Optional, Literal
import logging

import torch
import soundfile as sf
from fastapi import FastAPI, HTTPException
from fastapi.responses import Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model variables - loaded once at startup
tts_model = None
device = None
sample_rate = None


class SpeechRequest(BaseModel):
    model: Literal["tts-1", "tts-1-hd"] = Field("tts-1", description="TTS model to use")
    input: str = Field(..., min_length=1, max_length=4096, description="Text to generate audio for")
    voice: Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"] = Field("alloy", description="Voice to use")
    response_format: Literal["mp3", "opus", "aac", "flac", "wav", "pcm"] = Field("mp3", description="Audio format")
    speed: Optional[float] = Field(1.0, ge=0.25, le=4.0, description="Speed of generated audio")


app = FastAPI(
    title="OpenAI-Compatible TTS API (Cached)",
    description="OpenAI Audio Speech API compatible endpoint using Kyutai TTS with model caching",
    version="2.0.0"
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

OUTPUT_DIR = Path("/app/api_output")
OUTPUT_DIR.mkdir(exist_ok=True)


def load_tts_model():
    """Load TTS model once at startup and keep in memory"""
    global tts_model, device, sample_rate

    if tts_model is not None:
        logger.info("TTS model already loaded")
        return

    try:
        logger.info("🚀 Loading Kyutai TTS model (one-time initialization)...")

        # Import Kyutai TTS modules
        from moshi.models.loaders import CheckpointInfo
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, TTSModel

        # Set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        # Load the TTS model
        checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)
        tts_model = TTSModel.from_checkpoint_info(
            checkpoint_info,
            n_q=32,
            temp=0.6,
            device=device
        )

        # Get sample rate
        sample_rate = tts_model.mimi.sample_rate

        logger.info("✅ TTS model loaded successfully!")
        logger.info(f"   Model: {DEFAULT_DSM_TTS_REPO}")
        logger.info(f"   Device: {device}")
        logger.info(f"   Sample Rate: {sample_rate}")

    except Exception as e:
        logger.error(f"❌ Failed to load TTS model: {e}")
        raise


def generate_audio_fast(text: str, voice: str = "alloy", speed: float = 1.0) -> bytes:
    """Generate audio using cached TTS model"""
    global tts_model, device, sample_rate

    if tts_model is None:
        raise HTTPException(status_code=500, detail="TTS model not loaded")

    try:
        logger.info(f"🎵 Generating audio for: '{text[:50]}{'...' if len(text) > 50 else ''}'")

        # Prepare the script (text input)
        entries = tts_model.prepare_script([text], padding_between=1)

        # Voice mapping for OpenAI compatibility
        voice_mapping = {
            "alloy": "expresso/ex03-ex01_happy_001_channel1_334s.wav",
            "echo": "expresso/ex04-ex01_happy_001_channel1_334s.wav",
            "fable": "expresso/ex05-ex01_happy_001_channel1_334s.wav",
            "onyx": "expresso/ex06-ex01_happy_001_channel1_334s.wav",
            "nova": "expresso/ex07-ex01_happy_001_channel1_334s.wav",
            "shimmer": "expresso/ex08-ex01_happy_001_channel1_334s.wav"
        }

        selected_voice = voice_mapping.get(voice, voice_mapping["alloy"])

        try:
            voice_path = tts_model.get_voice_path(selected_voice)
        except Exception:
            # Fallback to default if voice not found
            voice_path = tts_model.get_voice_path("expresso/ex03-ex01_happy_001_channel1_334s.wav")

        # Prepare condition attributes
        condition_attributes = tts_model.make_condition_attributes(
            [voice_path], cfg_coef=2.0
        )

        # Generate audio
        pcms = []

        def on_frame(frame):
            if (frame != -1).all():
                pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                pcms.append(torch.clamp(torch.from_numpy(pcm[0, 0]), -1, 1).numpy())

        all_entries = [entries]
        all_condition_attributes = [condition_attributes]

        with tts_model.mimi.streaming(len(all_entries)):
            result = tts_model.generate(all_entries, all_condition_attributes, on_frame=on_frame)

        # Concatenate all audio frames
        if pcms:
            import numpy as np
            audio = np.concatenate(pcms, axis=-1)

            # Apply speed adjustment if needed
            if speed != 1.0:
                # Simple speed adjustment by resampling
                from scipy import signal
                audio_length = len(audio)
                new_length = int(audio_length / speed)
                audio = signal.resample(audio, new_length)

            # Convert to bytes
            audio_bytes = io.BytesIO()
            sf.write(audio_bytes, audio, samplerate=sample_rate, format='WAV')
            audio_bytes.seek(0)

            logger.info(f"✅ Audio generated successfully ({len(audio)/sample_rate:.2f}s)")
            return audio_bytes.read()
        else:
            raise Exception("No audio frames generated")

    except Exception as e:
        logger.error(f"❌ TTS generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Audio generation failed: {str(e)}")


def convert_audio_format(audio_wav_bytes: bytes, output_format: str) -> bytes:
    """Convert WAV audio to requested format using ffmpeg"""
    try:
        if output_format == "wav":
            return audio_wav_bytes

        # Use ffmpeg to convert
        cmd = ["ffmpeg", "-f", "wav", "-i", "pipe:0", "-f", output_format, "pipe:1"]
        result = subprocess.run(
            cmd,
            input=audio_wav_bytes,
            capture_output=True,
            check=True
        )
        return result.stdout

    except subprocess.CalledProcessError as e:
        logger.error(f"Audio conversion failed: {e}")
        raise HTTPException(status_code=500, detail=f"Audio conversion failed: {e}")


@app.post("/v1/audio/speech")
async def create_speech(request: SpeechRequest):
    """
    OpenAI-compatible audio speech endpoint
    Uses cached TTS model for fast generation
    """
    try:
        start_time = time.time()

        # Generate audio with cached model
        audio_wav_bytes = generate_audio_fast(
            text=request.input,
            voice=request.voice,
            speed=request.speed
        )

        # Convert to requested format
        audio_data = convert_audio_format(audio_wav_bytes, request.response_format)

        generation_time = time.time() - start_time
        logger.info(f"⚡ Total generation time: {generation_time:.2f}s")

        # Set appropriate content type
        content_types = {
            "mp3": "audio/mpeg",
            "opus": "audio/opus",
            "aac": "audio/aac",
            "flac": "audio/flac",
            "wav": "audio/wav",
            "pcm": "audio/pcm"
        }

        return Response(
            content=audio_data,
            media_type=content_types.get(request.response_format, "audio/wav"),
            headers={
                "Content-Disposition": f"attachment; filename=speech.{request.response_format}",
                "X-Generation-Time": str(generation_time)
            }
        )

    except Exception as e:
        logger.error(f"Speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/v1/models")
async def list_models():
    """List available models (OpenAI-compatible)"""
    return {
        "object": "list",
        "data": [
            {
                "id": "tts-1",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1",
                "parent": None
            },
            {
                "id": "tts-1-hd",
                "object": "model",
                "created": 1677610602,
                "owned_by": "kyutai",
                "permission": [],
                "root": "tts-1-hd",
                "parent": None
            }
        ]
    }


@app.get("/health")
async def health_check():
    """Health check endpoint with model status"""
    model_loaded = tts_model is not None
    return {
        "status": "healthy" if model_loaded else "loading",
        "model_loaded": model_loaded,
        "cuda_available": torch.cuda.is_available(),
        "device": str(device) if device else None,
        "service": "kyutai-tts-openai-compatible-cached"
    }


@app.get("/reload-model")
async def reload_model():
    """Reload the TTS model (admin endpoint)"""
    global tts_model
    try:
        tts_model = None
        load_tts_model()
        return {"status": "success", "message": "Model reloaded successfully"}
    except Exception as e:
        return {"status": "error", "message": str(e)}


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    logger.info("🚀 Starting TTS API server with model caching...")
    load_tts_model()


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
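For reference, here is one way to exercise the `/v1/audio/speech` endpoint defined above. A minimal stdlib-only sketch, assuming the server is reachable on `localhost:8000` (the host/port from the `uvicorn.run` call) and requesting `wav` so the ffmpeg conversion path is skipped:

```python
import json
import urllib.request

payload = {
    "model": "tts-1",
    "input": "Hello from the cached Kyutai TTS server.",
    "voice": "alloy",
    "response_format": "wav",
}
req = urllib.request.Request(
    "http://localhost:8000/v1/audio/speech",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    # X-Generation-Time is set by the server in create_speech above.
    print("generation time:", resp.headers.get("X-Generation-Time"))
    with open("speech.wav", "wb") as f:
        f.write(resp.read())
```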
app/dependency_check.py (new file)
@@ -0,0 +1,67 @@
#!/usr/bin/env python3
"""
Check if all Kyutai TTS dependencies are properly installed
"""

import sys


def check_dependencies():
    print("🔍 Checking Kyutai TTS Dependencies")
    print("=" * 40)

    dependencies = [
        "torch",
        "numpy",
        "einops",
        "transformers",
        "accelerate",
        "soundfile",
        "librosa",
        "huggingface_hub",
        "moshi",
        "sphn"
    ]

    missing = []
    installed = []

    for dep in dependencies:
        try:
            __import__(dep)
            installed.append(dep)
            print(f"✓ {dep}")
        except ImportError as e:
            missing.append((dep, str(e)))
            print(f"✗ {dep}: {e}")

    print("\n📊 Summary:")
    print(f"✓ Installed: {len(installed)}")
    print(f"✗ Missing: {len(missing)}")

    if missing:
        print("\n🔧 To fix missing dependencies:")
        for dep, error in missing:
            print(f"pip install {dep}")

    print("\n🧪 Testing Kyutai TTS imports:")
    try:
        from moshi.models.loaders import CheckpointInfo
        print("✓ CheckpointInfo import successful")
    except Exception as e:
        print(f"✗ CheckpointInfo import failed: {e}")

    try:
        from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
        print("✓ TTSModel imports successful")
    except Exception as e:
        print(f"✗ TTSModel imports failed: {e}")

    return len(missing) == 0


if __name__ == "__main__":
    success = check_dependencies()
    if success:
        print("\n🎉 All dependencies are installed correctly!")
    else:
        print("\n❌ Some dependencies are missing. Please install them first.")
        sys.exit(1)
app/scripts/tts_runner.py (new file)
@@ -0,0 +1,59 @@
#!/usr/bin/env python3
"""
Kyutai TTS PyTorch Runner
Dockerized implementation for text-to-speech generation
"""
import sys
import os
import argparse
import torch
from pathlib import Path


def main():
    parser = argparse.ArgumentParser(description='Kyutai TTS PyTorch Runner')
    parser.add_argument('input_file', help='Input text file or "-" for stdin')
    parser.add_argument('output_file', help='Output audio file')
    parser.add_argument('--model', default='kyutai/tts-1.6b-en_fr', help='TTS model to use')
    parser.add_argument('--device', default='cuda' if torch.cuda.is_available() else 'cpu', help='Device to use')

    args = parser.parse_args()

    print(f"Using device: {args.device}")
    print(f"CUDA available: {torch.cuda.is_available()}")

    # Handle stdin input
    if args.input_file == '-':
        # Read from stdin and create temporary file
        text = sys.stdin.read().strip()
        temp_file = '/tmp/temp_input.txt'
        with open(temp_file, 'w') as f:
            f.write(text)
        input_file = temp_file
    else:
        input_file = args.input_file

    # Check if the original TTS script exists
    tts_script = Path('/app/scripts/tts_pytorch.py')
    if tts_script.exists():
        print("Using original TTS script from Kyutai repository")
        import subprocess
        cmd = ['python', str(tts_script), input_file, args.output_file]
        subprocess.run(cmd, check=True)
    else:
        print("Using moshi package for TTS generation")
        import subprocess
        cmd = [
            'python', '-m', 'moshi.run_inference',
            '--hf-repo', args.model,
            input_file,
            args.output_file
        ]
        result = subprocess.run(cmd, capture_output=True, text=True)
        if result.returncode != 0:
            print(f"Error: {result.stderr}")
            sys.exit(1)
    print(f"Audio generated: {args.output_file}")


if __name__ == '__main__':
    main()
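As a usage sketch for the runner above (untested; the paths are the script's own defaults), text can be piped through stdin: `echo "Hello world" | python app/scripts/tts_runner.py - /app/output/out.wav`, which routes the input through the temporary-file handling at the top of `main()`.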
install.sh (new file)
@@ -0,0 +1,78 @@
# Set environment variables
export DEBIAN_FRONTEND=noninteractive
export PYTHONUNBUFFERED=1
export CUDA_VISIBLE_DEVICES=0

# Install system dependencies
apt-get update && apt-get install -y \
    wget \
    curl \
    git \
    build-essential \
    libsndfile1 \
    ffmpeg \
    sox \
    alsa-utils \
    pulseaudio \
    && rm -rf /var/lib/apt/lists/*

# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip

# Create virtual environment
apt install python3.12-venv python3.12-dev
python3.12 -m venv ~/venv-tts-kyutai
source ~/venv-tts-kyutai/bin/activate

# Install Python dependencies first (for better caching)
pip install --no-cache-dir --upgrade pip

# Install PyTorch with CUDA support for Python 3.12
pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install core dependencies
pip install --no-cache-dir \
    numpy \
    scipy \
    librosa \
    soundfile \
    huggingface_hub \
    einops \
    transformers \
    accelerate

# Install API dependencies
pip install --no-cache-dir \
    fastapi \
    uvicorn[standard] \
    python-multipart \
    pydantic

# Install moshi package with all dependencies (following Colab notebook)
pip install --no-cache-dir 'sphn<0.2'
pip install --no-cache-dir "moshi==0.2.8"

# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output

# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo

# Copy the TTS script from the repository
cp /app/kyutai-repo/scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"

# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output

# Download the Kyutai delayed-streams-modeling repository
#git clone https://github.com/kyutai-labs/delayed-streams-modeling.git /app/kyutai-repo

# Copy the TTS script from the repository
cp scripts/tts_pytorch.py /app/scripts/ || echo "TTS script not found, will create custom one"

# Create directories for input/output
mkdir -p /app/input /app/output /app/scripts /app/api_output

# Start TTS-Server
python /app/api_server.py
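Once `install.sh` has started `api_server.py`, readiness can be checked against the `/health` route the server defines. A minimal polling sketch, assuming the default `localhost:8000` bind:

```python
import json
import time
import urllib.request

# Poll /health until the server reports the TTS model as loaded.
while True:
    try:
        with urllib.request.urlopen("http://localhost:8000/health") as resp:
            status = json.load(resp)
        if status.get("model_loaded"):
            print(f"ready on {status.get('device')}")
            break
        print("model still loading...")
    except OSError:
        print("server not reachable yet")
    time.sleep(5)
```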
@@ -2,7 +2,7 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx==0.2.10",
+#     "moshi_mlx==0.2.12",
 #     "numpy",
 #     "sentencepiece",
 #     "sounddevice",

@@ -24,13 +24,18 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("in_file", help="The file to transcribe.")
     parser.add_argument("--max-steps", default=4096)
-    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    parser.add_argument("--hf-repo")
     parser.add_argument(
         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
     )
     args = parser.parse_args()

     audio, _ = sphn.read(args.in_file, sample_rate=24000)
+    if args.hf_repo is None:
+        if args.vad:
+            args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+        else:
+            args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
     lm_config = hf_hub_download(args.hf_repo, "config.json")
     with open(lm_config, "r") as fobj:
         lm_config = json.load(fobj)
@@ -4,7 +4,7 @@
 #     "julius",
 #     "librosa",
 #     "soundfile",
-#     "moshi==0.2.9",
+#     "moshi==0.2.11",
 # ]
 # ///

@@ -128,6 +128,9 @@ def tokens_to_timestamped_text(


 def main(args):
+    if args.vad and args.hf_repo is None:
+        args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+
     info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
         args.hf_repo,
         moshi_weights=args.moshi_weight,
@@ -2,7 +2,7 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx==0.2.10",
+#     "moshi_mlx==0.2.12",
 #     "numpy",
 #     "rustymimi",
 #     "sentencepiece",

@@ -25,12 +25,17 @@ from moshi_mlx import models, utils
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--max-steps", default=4096)
-    parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
+    parser.add_argument("--hf-repo")
     parser.add_argument(
         "--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
     )
     args = parser.parse_args()

+    if args.hf_repo is None:
+        if args.vad:
+            args.hf_repo = "kyutai/stt-1b-en_fr-candle"
+        else:
+            args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
     lm_config = hf_hub_download(args.hf_repo, "config.json")
     with open(lm_config, "r") as fobj:
         lm_config = json.load(fobj)
@@ -2,7 +2,7 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx==0.2.11",
+#     "moshi_mlx==0.2.12",
 #     "numpy",
 #     "sounddevice",
 # ]

@@ -76,6 +76,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)
@@ -2,7 +2,7 @@
 # requires-python = ">=3.12"
 # dependencies = [
 #     "huggingface_hub",
-#     "moshi_mlx==0.2.11",
+#     "moshi_mlx==0.2.12",
 #     "numpy",
 #     "sounddevice",
 # ]

@@ -205,6 +205,9 @@ def main():
     moshi_weights = hf_get(moshi_name, args.hf_repo)
     tokenizer = hf_get(raw_config["tokenizer_name"], args.hf_repo)
     lm_config = models.LmConfig.from_config_dict(raw_config)
+    # There is a bug in moshi_mlx <= 0.3.0 handling of the ring kv cache.
+    # The following line gets around it for now.
+    lm_config.transformer.max_seq_len = lm_config.transformer.context
     model = models.Lm(lm_config)
     model.set_dtype(mx.bfloat16)
@@ -1,7 +1,7 @@
 # /// script
 # requires-python = ">=3.12"
 # dependencies = [
-#     "moshi==0.2.10",
+#     "moshi==0.2.11",
 #     "torch",
 #     "sphn",
 #     "sounddevice",

@@ -68,6 +68,9 @@ def main():
     # If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]
     entries = tts_model.prepare_script([text], padding_between=1)
-    voice_path = tts_model.get_voice_path(args.voice)
+    if args.voice.endswith(".safetensors"):
+        voice_path = args.voice
+    else:
+        voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.

@@ -75,6 +78,7 @@ def main():
     condition_attributes = tts_model.make_condition_attributes(
         [voice_path], cfg_coef=2.0
     )
+    _frames_cnt = 0

     if args.out == "-":
         # Stream the audio to the speakers using sounddevice.

@@ -83,9 +87,12 @@ def main():
         pcms = queue.Queue()

         def _on_frame(frame):
+            nonlocal _frames_cnt
             if (frame != -1).all():
                 pcm = tts_model.mimi.decode(frame[:, 1:, :]).cpu().numpy()
                 pcms.put_nowait(np.clip(pcm[0, 0], -1, 1))
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)

         def audio_callback(outdata, _a, _b, _c):
             try:

@@ -110,7 +117,16 @@ def main():
                 break
             time.sleep(1)
     else:
-        result = tts_model.generate([entries], [condition_attributes])
+
+        def _on_frame(frame):
+            nonlocal _frames_cnt
+            if (frame != -1).all():
+                _frames_cnt += 1
+                print(f"generated {_frames_cnt / 12.5:.2f}s", end="\r", flush=True)
+
+        result = tts_model.generate(
+            [entries], [condition_attributes], on_frame=_on_frame
+        )
     with tts_model.mimi.streaming(1), torch.no_grad():
         pcms = []
         for frame in result.frames[tts_model.delay_steps :]:
@@ -1,7 +1,7 @@
 # /// script
 # requires-python = ">=3.12"
 # dependencies = [
-#     "moshi==0.2.10",
+#     "moshi==0.2.11",
 #     "torch",
 #     "sphn",
 #     "sounddevice",

@@ -183,6 +183,9 @@ def main():
         checkpoint_info, n_q=32, temp=0.6, device=args.device
     )

-    voice_path = tts_model.get_voice_path(args.voice)
+    if args.voice.endswith(".safetensors"):
+        voice_path = args.voice
+    else:
+        voice_path = tts_model.get_voice_path(args.voice)
     # CFG coef goes here because the model was trained with CFG distillation,
     # so it's not _actually_ doing CFG at inference time.
@@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
     print(f"Saved audio to {out}")


+async def read_lines_from_stdin():
+    reader = asyncio.StreamReader()
+    protocol = asyncio.StreamReaderProtocol(reader)
+    loop = asyncio.get_running_loop()
+    await loop.connect_read_pipe(lambda: protocol, sys.stdin)
+    while True:
+        line = await reader.readline()
+        if not line:
+            break
+        yield line.decode().rstrip()
+
+
+async def read_lines_from_file(path: str):
+    queue = asyncio.Queue()
+    loop = asyncio.get_running_loop()
+
+    def producer():
+        with open(path, "r", encoding="utf-8") as f:
+            for line in f:
+                asyncio.run_coroutine_threadsafe(queue.put(line), loop)
+        asyncio.run_coroutine_threadsafe(queue.put(None), loop)
+
+    await asyncio.to_thread(producer)
+    while True:
+        line = await queue.get()
+        if line is None:
+            break
+        yield line
+
+
+async def get_lines(source: str):
+    if source == "-":
+        async for line in read_lines_from_stdin():
+            yield line
+    else:
+        async for line in read_lines_from_file(source):
+            yield line
+
+
 async def websocket_client():
     parser = argparse.ArgumentParser(description="Use the TTS streaming API")
     parser.add_argument("inp", type=str, help="Input file, use - for stdin.")

@@ -113,25 +152,26 @@ async def websocket_client():
     uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
     print(uri)

-    # TODO: stream the text instead of sending it all at once
     if args.inp == "-":
         if sys.stdin.isatty():  # Interactive
             print("Enter text to synthesize (Ctrl+D to end input):")
-        text_to_tts = sys.stdin.read().strip()
-    else:
-        with open(args.inp, "r") as fobj:
-            text_to_tts = fobj.read().strip()
-
     headers = {"kyutai-api-key": args.api_key}

     async with websockets.connect(uri, additional_headers=headers) as websocket:
-        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
-        await websocket.send(msgpack.packb({"type": "Eos"}))
+        print("connected")
+
+        async def send_loop():
+            print("go send")
+            async for line in get_lines(args.inp):
+                for word in line.split():
+                    await websocket.send(msgpack.packb({"type": "Text", "text": word}))
+            await websocket.send(msgpack.packb({"type": "Eos"}))

         output_queue = asyncio.Queue()
         receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
         output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
-        await asyncio.gather(receive_task, output_audio_task)
+        send_task = asyncio.create_task(send_loop())
+        await asyncio.gather(receive_task, output_audio_task, send_task)


 if __name__ == "__main__":
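The streaming protocol visible in this client is msgpack-framed dicts: one `Text` message per word, then a single `Eos`. A minimal sketch of just the framing; only the two message shapes above are taken from the diff, everything else is illustrative:

```python
import msgpack

# One frame per word, as in send_loop above.
words = "hello streaming world".split()
frames = [msgpack.packb({"type": "Text", "text": w}) for w in words]
frames.append(msgpack.packb({"type": "Eos"}))

# Each frame decodes back to its dict independently.
for frame in frames:
    print(msgpack.unpackb(frame))
```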
@@ -9,9 +9,9 @@
    "source": [
     "# Fast install, might break in the future.\n",
     "!pip install 'sphn<0.2'\n",
-    "!pip install --no-deps \"moshi==0.2.10\"\n",
+    "!pip install --no-deps \"moshi==0.2.11\"\n",
     "# Slow install (will download torch and cuda), but future proof.\n",
-    "# !pip install \"moshi==0.2.10\""
+    "# !pip install \"moshi==0.2.11\""
    ]
   },
   {