From 70500c620e444bc0892bc5e08038c8d22d3f14ad Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Mon, 7 Jul 2025 08:36:47 +0200 Subject: [PATCH] Add a device argument to the tts pytorch script. (#62) --- scripts/tts_pytorch.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index 568b9c8..f27faaf 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -44,12 +44,18 @@ def main(): help="The voice to use, relative to the voice repo root. " f"See {DEFAULT_DSM_TTS_VOICE_REPO}", ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device on which to run, defaults to 'cuda'.", + ) args = parser.parse_args() print("Loading model...") checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) tts_model = TTSModel.from_checkpoint_info( - checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda") + checkpoint_info, n_q=32, temp=0.6, device=args.device ) if args.inp == "-":