diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index 568b9c8..f27faaf 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -44,12 +44,18 @@ def main(): help="The voice to use, relative to the voice repo root. " f"See {DEFAULT_DSM_TTS_VOICE_REPO}", ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="Device on which to run, defaults to 'cuda'.", + ) args = parser.parse_args() print("Loading model...") checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) tts_model = TTSModel.from_checkpoint_info( - checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda") + checkpoint_info, n_q=32, temp=0.6, device=args.device ) if args.inp == "-":