Add a device argument to the tts pytorch script. (#62)

This commit is contained in:
Laurent Mazare 2025-07-07 08:36:47 +02:00 committed by GitHub
parent f8e97aa4f3
commit 70500c620e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -44,12 +44,18 @@ def main():
help="The voice to use, relative to the voice repo root. " help="The voice to use, relative to the voice repo root. "
f"See {DEFAULT_DSM_TTS_VOICE_REPO}", f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
) )
parser.add_argument(
"--device",
type=str,
default="cuda",
help="Device on which to run, defaults to 'cuda'.",
)
args = parser.parse_args() args = parser.parse_args()
print("Loading model...") print("Loading model...")
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
tts_model = TTSModel.from_checkpoint_info( tts_model = TTSModel.from_checkpoint_info(
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda") checkpoint_info, n_q=32, temp=0.6, device=args.device
) )
if args.inp == "-": if args.inp == "-":