Compare commits

...

3 Commits

Author SHA1 Message Date
laurent
a1af3a78cb Another formatting tweak. 2025-07-31 17:40:35 +02:00
laurent
ecc002f7af Formatting. 2025-07-31 17:38:34 +02:00
laurent
94e271b69a Use a streaming input in the rust example. 2025-07-31 17:37:43 +02:00

View File

@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
print(f"Saved audio to {out}") print(f"Saved audio to {out}")
async def read_lines_from_stdin():
reader = asyncio.StreamReader()
protocol = asyncio.StreamReaderProtocol(reader)
loop = asyncio.get_running_loop()
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
while True:
line = await reader.readline()
if not line:
break
yield line.decode().rstrip()
async def read_lines_from_file(path: str):
queue = asyncio.Queue()
loop = asyncio.get_running_loop()
def producer():
with open(path, "r", encoding="utf-8") as f:
for line in f:
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
await asyncio.to_thread(producer)
while True:
line = await queue.get()
if line is None:
break
yield line
async def get_lines(source: str):
if source == "-":
async for line in read_lines_from_stdin():
yield line
else:
async for line in read_lines_from_file(source):
yield line
async def websocket_client(): async def websocket_client():
parser = argparse.ArgumentParser(description="Use the TTS streaming API") parser = argparse.ArgumentParser(description="Use the TTS streaming API")
parser.add_argument("inp", type=str, help="Input file, use - for stdin.") parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
@ -113,25 +152,26 @@ async def websocket_client():
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}" uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
print(uri) print(uri)
# TODO: stream the text instead of sending it all at once
if args.inp == "-": if args.inp == "-":
if sys.stdin.isatty(): # Interactive if sys.stdin.isatty(): # Interactive
print("Enter text to synthesize (Ctrl+D to end input):") print("Enter text to synthesize (Ctrl+D to end input):")
text_to_tts = sys.stdin.read().strip()
else:
with open(args.inp, "r") as fobj:
text_to_tts = fobj.read().strip()
headers = {"kyutai-api-key": args.api_key} headers = {"kyutai-api-key": args.api_key}
async with websockets.connect(uri, additional_headers=headers) as websocket: async with websockets.connect(uri, additional_headers=headers) as websocket:
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts})) print("connected")
async def send_loop():
print("go send")
async for line in get_lines(args.inp):
for word in line.split():
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
await websocket.send(msgpack.packb({"type": "Eos"})) await websocket.send(msgpack.packb({"type": "Eos"}))
output_queue = asyncio.Queue() output_queue = asyncio.Queue()
receive_task = asyncio.create_task(receive_messages(websocket, output_queue)) receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue)) output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
await asyncio.gather(receive_task, output_audio_task) send_task = asyncio.create_task(send_loop())
await asyncio.gather(receive_task, output_audio_task, send_task)
if __name__ == "__main__": if __name__ == "__main__":