diff --git a/scripts/tts_rust_server.py b/scripts/tts_rust_server.py
index 9b67dc8..77ec367 100644
--- a/scripts/tts_rust_server.py
+++ b/scripts/tts_rust_server.py
@@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
     print(f"Saved audio to {out}")
 
 
+async def read_lines_from_stdin():
+    """Yield lines from stdin without blocking the event loop.
+
+    Each blocking ``readline`` call runs in a worker thread rather than via
+    ``loop.connect_read_pipe``, which only accepts pipes, sockets and character
+    devices and so raises ``ValueError`` when stdin is redirected from a
+    regular file (``- < input.txt``).
+
+    Yields:
+        Each line with trailing whitespace, including the newline, removed.
+    """
+    while line := await asyncio.to_thread(sys.stdin.readline):
+        yield line.rstrip()
+
+
+async def read_lines_from_file(path: str):
+    """Yield the lines of the UTF-8 text file at ``path`` one at a time.
+
+    Lines are read lazily, each in a worker thread, so the caller can start
+    sending before the whole file has been read.
+
+    Yields:
+        Each line as read, with its trailing newline if it has one.
+    """
+    with open(path, "r", encoding="utf-8") as f:
+        while line := await asyncio.to_thread(f.readline):
+            yield line
+
+
+async def get_lines(source: str):
+    """Yield input lines from stdin when ``source`` is "-", else from that file."""
+    if source == "-":
+        async for line in read_lines_from_stdin():
+            yield line
+    else:
+        async for line in read_lines_from_file(source):
+            yield line
+
+
 async def websocket_client():
     parser = argparse.ArgumentParser(description="Use the TTS streaming API")
     parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
@@ -113,25 +152,26 @@ async def websocket_client():
     uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
     print(uri)
 
-    # TODO: stream the text instead of sending it all at once
     if args.inp == "-":
         if sys.stdin.isatty():  # Interactive
             print("Enter text to synthesize (Ctrl+D to end input):")
-        text_to_tts = sys.stdin.read().strip()
-    else:
-        with open(args.inp, "r") as fobj:
-            text_to_tts = fobj.read().strip()
-
     headers = {"kyutai-api-key": args.api_key}
 
     async with websockets.connect(uri, additional_headers=headers) as websocket:
-        await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
-        await websocket.send(msgpack.packb({"type": "Eos"}))
+        print("connected")
+
+        async def send_loop():
+            print("go send")
+            async for line in get_lines(args.inp):
+                for word in line.split():
+                    await websocket.send(msgpack.packb({"type": "Text", "text": word}))
+            await websocket.send(msgpack.packb({"type": "Eos"}))
 
         output_queue = asyncio.Queue()
         receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
         output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
-        await asyncio.gather(receive_task, output_audio_task)
+        send_task = asyncio.create_task(send_loop())
+        await asyncio.gather(receive_task, output_audio_task, send_task)
 
 
 if __name__ == "__main__":