Use bfloat16 rather than half by default.
This commit is contained in:
parent
f9739881e6
commit
bfc200f6ee
|
|
@ -49,7 +49,7 @@ def main():
|
||||||
print("Loading model...")
|
print("Loading model...")
|
||||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||||
tts_model = TTSModel.from_checkpoint_info(
|
tts_model = TTSModel.from_checkpoint_info(
|
||||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.inp == "-":
|
if args.inp == "-":
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@
|
||||||
"# Set everything up\n",
|
"# Set everything up\n",
|
||||||
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
|
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
|
||||||
"tts_model = TTSModel.from_checkpoint_info(\n",
|
"tts_model = TTSModel.from_checkpoint_info(\n",
|
||||||
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
|
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
|
||||||
")\n",
|
")\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
|
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user