Use bfloat16 rather than half by default.
This commit is contained in:
parent
f9739881e6
commit
bfc200f6ee
|
|
@ -49,7 +49,7 @@ def main():
|
|||
print("Loading model...")
|
||||
checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
|
||||
tts_model = TTSModel.from_checkpoint_info(
|
||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half
|
||||
checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
|
||||
)
|
||||
|
||||
if args.inp == "-":
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@
|
|||
"# Set everything up\n",
|
||||
"checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n",
|
||||
"tts_model = TTSModel.from_checkpoint_info(\n",
|
||||
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n",
|
||||
" checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user