diff --git a/README.md b/README.md
index 0daabaa..79901e9 100644
--- a/README.md
+++ b/README.md
@@ -234,6 +234,9 @@ echo "Hey, how are you?" | python scripts/tts_pytorch.py - -
 
 # From text file to audio file
 python scripts/tts_pytorch.py text_to_say.txt audio_output.wav
+
+# Use --cpu flag for CPU-only inference
+python scripts/tts_pytorch.py --cpu text_to_say.txt audio_output.wav
 ```
 
 This requires the [moshi package](https://pypi.org/project/moshi/), which can be installed via pip.
diff --git a/scripts/tts_pytorch.py b/scripts/tts_pytorch.py
index 9230319..750cbbd 100644
--- a/scripts/tts_pytorch.py
+++ b/scripts/tts_pytorch.py
@@ -44,12 +44,26 @@ def main():
         help="The voice to use, relative to the voice repo root. "
         f"See {DEFAULT_DSM_TTS_VOICE_REPO}",
     )
+    parser.add_argument(
+        "--cpu",
+        action="store_true",
+        help="Use CPU instead of GPU for inference",
+    )
     args = parser.parse_args()
 
     print("Loading model...")
     checkpoint_info = CheckpointInfo.from_hf_repo(args.hf_repo)
+
+    # Set device and precision
+    if args.cpu:
+        device = torch.device("cpu")
+        dtype = torch.float32
+    else:
+        device = torch.device("cuda")
+        dtype = torch.bfloat16
+
     tts_model = TTSModel.from_checkpoint_info(
-        checkpoint_info, n_q=32, temp=0.6, device=torch.device("cuda")
+        checkpoint_info, n_q=32, temp=0.6, device=device, dtype=dtype
     )
 
     if args.inp == "-":
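
For reference, here is a minimal standalone sketch (not part of the patch itself) of the flag-to-device mapping the diff adds inside `main()`. The rationale comments are assumptions, not statements from the diff: float32 is the conventional safe precision on CPU, while bfloat16 reduces memory use and speeds up inference on CUDA GPUs that support it.

```python
import argparse

import torch

# Sketch of the --cpu flag handling introduced by the patch.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--cpu",
    action="store_true",
    help="Use CPU instead of GPU for inference",
)
args = parser.parse_args()

if args.cpu:
    # Assumption: float32 is the safe default precision on CPU.
    device, dtype = torch.device("cpu"), torch.float32
else:
    # Assumption: bfloat16 lowers memory use on supported CUDA devices.
    device, dtype = torch.device("cuda"), torch.bfloat16

print(f"Inference will run on {device} with dtype {dtype}")
```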