Compare commits
2 Commits
main
...
vv/tts-pyt
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31b425cb6f | ||
|
|
3de0606614 |
3
.vscode/settings.json
vendored
Normal file
3
.vscode/settings.json
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
{
|
||||||
|
"python.analysis.typeCheckingMode": "standard"
|
||||||
|
}
|
||||||
106
scripts/tts_pytorch.py
Normal file
106
scripts/tts_pytorch.py
Normal file
|
|
@ -0,0 +1,106 @@
|
||||||
|
# /// script
|
||||||
|
# requires-python = ">=3.12"
|
||||||
|
# dependencies = [
|
||||||
|
# "moshi @ git+https://git@github.com/kyutai-labs/moshi#egg=moshi&subdirectory=moshi",
|
||||||
|
# "torch",
|
||||||
|
# "sphn",
|
||||||
|
# "sounddevice",
|
||||||
|
# ]
|
||||||
|
# ///
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import sphn
|
||||||
|
import torch
|
||||||
|
from moshi.models.loaders import CheckpointInfo
|
||||||
|
from moshi.models.tts import DEFAULT_DSM_TTS_REPO, DEFAULT_DSM_TTS_VOICE_REPO, TTSModel
|
||||||
|
|
||||||
|
|
||||||
|
def audio_to_int16(audio: np.ndarray) -> np.ndarray:
|
||||||
|
if audio.dtype == np.int16:
|
||||||
|
return audio
|
||||||
|
elif audio.dtype == np.float32:
|
||||||
|
# Multiply by 32767 and not 32768 so that int16 doesn't overflow.
|
||||||
|
return (np.clip(audio, -1, 1) * 32767).astype(np.int16)
|
||||||
|
else:
|
||||||
|
raise TypeError(f"Unsupported audio data type: {audio.dtype}")
|
||||||
|
|
||||||
|
|
||||||
|
def play_audio(audio: np.ndarray, sample_rate: int):
|
||||||
|
# Requires the Portaudio library which might not be available in all environments.
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
with sd.OutputStream(samplerate=sample_rate, blocksize=1920, channels=1):
|
||||||
|
sd.play(audio, sample_rate)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Run Kyutai TTS using the PyTorch implementation"
|
||||||
|
)
|
||||||
|
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||||
|
parser.add_argument(
|
||||||
|
"out", type=str, help="Output file to generate, use - for playing the audio"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--hf-repo",
|
||||||
|
type=str,
|
||||||
|
default=DEFAULT_DSM_TTS_REPO,
|
||||||
|
help="HF repo in which to look for the pretrained models.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice-repo",
|
||||||
|
default=DEFAULT_DSM_TTS_VOICE_REPO,
|
||||||
|
help="HF repo in which to look for pre-computed voice embeddings.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--voice",
|
||||||
|
default="expresso/ex03-ex01_happy_001_channel1_334s.wav",
|
||||||
|
help="The voice to use, relative to the voice repo root. "
|
||||||
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
print("Loading model...")
|
||||||
|
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||||
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
|
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.inp == "-":
|
||||||
|
if sys.stdin.isatty(): # Interactive
|
||||||
|
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||||
|
text = sys.stdin.read().strip()
|
||||||
|
else:
|
||||||
|
with open(args.inp, "r") as fobj:
|
||||||
|
text = fobj.read().strip()
|
||||||
|
|
||||||
|
# You could also generate multiple audios at once by passing a list of texts.
|
||||||
|
entries = tts_model.prepare_script([text], padding_between=1)
|
||||||
|
voice_path = tts_model.get_voice_path(args.voice)
|
||||||
|
# CFG coef goes here because the model was trained with CFG distillation,
|
||||||
|
# so it's not _actually_ doing CFG at inference time.
|
||||||
|
condition_attributes = tts_model.make_condition_attributes(
|
||||||
|
[voice_path], cfg_coef=2.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Generating audio...")
|
||||||
|
# This doesn't do streaming generation,
|
||||||
|
result = tts_model.generate([entries], [condition_attributes])
|
||||||
|
|
||||||
|
frames = torch.cat(result.frames, dim=-1)
|
||||||
|
audio_tokens = frames[:, tts_model.lm.audio_offset :, tts_model.delay_steps :]
|
||||||
|
with torch.no_grad():
|
||||||
|
audios = tts_model.mimi.decode(audio_tokens)
|
||||||
|
|
||||||
|
if args.out == "-":
|
||||||
|
print("Playing audio...")
|
||||||
|
play_audio(audios[0][0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||||
|
else:
|
||||||
|
sphn.write_wav(args.out, audios[0].cpu().numpy(), tts_model.mimi.sample_rate)
|
||||||
|
print(f"Audio saved to {args.out}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue
Block a user