Use a streaming input in the rust example. (#102)
* Use a streaming input in the rust example. * Formatting. * Another formatting tweak.
This commit is contained in:
parent
7dc926d50c
commit
af2283de3f
|
|
@ -89,6 +89,45 @@ async def output_audio(out: str, output_queue: asyncio.Queue[np.ndarray | None])
|
|||
print(f"Saved audio to {out}")
|
||||
|
||||
|
||||
async def read_lines_from_stdin():
|
||||
reader = asyncio.StreamReader()
|
||||
protocol = asyncio.StreamReaderProtocol(reader)
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.connect_read_pipe(lambda: protocol, sys.stdin)
|
||||
while True:
|
||||
line = await reader.readline()
|
||||
if not line:
|
||||
break
|
||||
yield line.decode().rstrip()
|
||||
|
||||
|
||||
async def read_lines_from_file(path: str):
|
||||
queue = asyncio.Queue()
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
def producer():
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
asyncio.run_coroutine_threadsafe(queue.put(line), loop)
|
||||
asyncio.run_coroutine_threadsafe(queue.put(None), loop)
|
||||
|
||||
await asyncio.to_thread(producer)
|
||||
while True:
|
||||
line = await queue.get()
|
||||
if line is None:
|
||||
break
|
||||
yield line
|
||||
|
||||
|
||||
async def get_lines(source: str):
|
||||
if source == "-":
|
||||
async for line in read_lines_from_stdin():
|
||||
yield line
|
||||
else:
|
||||
async for line in read_lines_from_file(source):
|
||||
yield line
|
||||
|
||||
|
||||
async def websocket_client():
|
||||
parser = argparse.ArgumentParser(description="Use the TTS streaming API")
|
||||
parser.add_argument("inp", type=str, help="Input file, use - for stdin.")
|
||||
|
|
@ -113,25 +152,26 @@ async def websocket_client():
|
|||
uri = f"{args.url}/api/tts_streaming?{urlencode(params)}"
|
||||
print(uri)
|
||||
|
||||
# TODO: stream the text instead of sending it all at once
|
||||
if args.inp == "-":
|
||||
if sys.stdin.isatty(): # Interactive
|
||||
print("Enter text to synthesize (Ctrl+D to end input):")
|
||||
text_to_tts = sys.stdin.read().strip()
|
||||
else:
|
||||
with open(args.inp, "r") as fobj:
|
||||
text_to_tts = fobj.read().strip()
|
||||
|
||||
headers = {"kyutai-api-key": args.api_key}
|
||||
|
||||
async with websockets.connect(uri, additional_headers=headers) as websocket:
|
||||
await websocket.send(msgpack.packb({"type": "Text", "text": text_to_tts}))
|
||||
print("connected")
|
||||
|
||||
async def send_loop():
|
||||
print("go send")
|
||||
async for line in get_lines(args.inp):
|
||||
for word in line.split():
|
||||
await websocket.send(msgpack.packb({"type": "Text", "text": word}))
|
||||
await websocket.send(msgpack.packb({"type": "Eos"}))
|
||||
|
||||
output_queue = asyncio.Queue()
|
||||
receive_task = asyncio.create_task(receive_messages(websocket, output_queue))
|
||||
output_audio_task = asyncio.create_task(output_audio(args.out, output_queue))
|
||||
await asyncio.gather(receive_task, output_audio_task)
|
||||
send_task = asyncio.create_task(send_loop())
|
||||
await asyncio.gather(receive_task, output_audio_task, send_task)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user