From 8927c0d9eb7065db87d4c3bf95212bd08958f203 Mon Sep 17 00:00:00 2001 From: laurent Date: Mon, 7 Jul 2025 08:31:03 +0200 Subject: [PATCH] Add a device argument to the tts pytorch script. --- scripts/tts_pytorch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index 568b9c8..f27faaf 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -44,12 +44,18 @@ def main(): help="The voice to use, relative to the voice repo root. " f"See {DEFAULT_DSM_TTS_VOICE_REPO}", ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device on which to run, defaults to 'cuda'.", + ) args = parser.parse_args() print("Loading model...") checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) tts_model = TTSModel.from_checkpoint_info( - checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda") + checkpoint_info, n_q=32, temp=0.6, device=args.device ) if args.inp == "-":