Add a device argument to the tts pytorch script. (#62)
This commit is contained in:
parent
f8e97aa4f3
commit
70500c620e
|
|
@ -44,12 +44,18 @@ def main():
|
||||||
help="The voice to use, relative to the voice repo root. "
|
help="The voice to use, relative to the voice repo root. "
|
||||||
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--device",
|
||||||
|
type=str,
|
||||||
|
default="cuda",
|
||||||
|
help="Device on which to run, defaults to 'cuda'.",
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
print("Loading model...")
|
print("Loading model...")
|
||||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||||
tts_model = TTSModel.from_checkpoint_info(
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
|
checkpoint_info, n_q=32, temp=0.6, device=args.device
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.inp == "-":
|
if args.inp == "-":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user