diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py index f513aaa..9230319 100644 --- a/scripts/tts_pytorch.py +++ b/scripts/tts_pytorch.py @@ -49,7 +49,7 @@ def main(): print("Loading model...") checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo) tts_model = TTSModel.from_checkpoint_info( - checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda"), dtype=torch.half + checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda") ) if args.inp == "-": diff --git a/tts_pytorch.ipynb b/tts_pytorch.ipynb index 9680bdd..1a206c0 100644 --- a/tts_pytorch.ipynb +++ b/tts_pytorch.ipynb @@ -55,7 +55,7 @@ "# Set everything up\n", "checkpoint_info = CheckpointInfo.from_hf_repo(DEFAULT_DSM_TTS_REPO)\n", "tts_model = TTSModel.from_checkpoint_info(\n", - " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\"), dtype=torch.half\n", + " checkpoint_info, n_q=32, temp=0.6, device=torch.device(\"cuda\")\n", ")\n", "\n", "# If you want to make a dialog, you can pass more than one turn [text_speaker_1, text_speaker_2, text_2_speaker_1, ...]\n",