Use the proper default repos when VAD is on. (#103)
This commit is contained in:
parent
af2283de3f
commit
07729ed47e
|
|
@ -24,13 +24,18 @@ if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("in_file", help="The file to transcribe.")
|
parser.add_argument("in_file", help="The file to transcribe.")
|
||||||
parser.add_argument("--max-steps", default=4096)
|
parser.add_argument("--max-steps", default=4096)
|
||||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
parser.add_argument("--hf-repo")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
audio, _ = sphn.read(args.in_file, sample_rate=24000)
|
||||||
|
if args.hf_repo is None:
|
||||||
|
if args.vad:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
else:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||||
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||||
with open(lm_config, "r") as fobj:
|
with open(lm_config, "r") as fobj:
|
||||||
lm_config = json.load(fobj)
|
lm_config = json.load(fobj)
|
||||||
|
|
|
||||||
|
|
@ -128,6 +128,9 @@ def tokens_to_timestamped_text(
|
||||||
|
|
||||||
|
|
||||||
def main(args):
|
def main(args):
|
||||||
|
if args.vad and args.hf_repo is None:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
|
||||||
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
info = moshi.models.loaders.CheckpointInfo.from_hf_repo(
|
||||||
args.hf_repo,
|
args.hf_repo,
|
||||||
moshi_weights=args.moshi_weight,
|
moshi_weights=args.moshi_weight,
|
||||||
|
|
|
||||||
|
|
@ -25,12 +25,17 @@ from moshi_mlx import models, utils
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--max-steps", default=4096)
|
parser.add_argument("--max-steps", default=4096)
|
||||||
parser.add_argument("--hf-repo", default="kyutai/stt-1b-en_fr-mlx")
|
parser.add_argument("--hf-repo")
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
"--vad", action="store_true", help="Enable VAD (Voice Activity Detection)."
|
||||||
)
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.hf_repo is None:
|
||||||
|
if args.vad:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-candle"
|
||||||
|
else:
|
||||||
|
args.hf_repo = "kyutai/stt-1b-en_fr-mlx"
|
||||||
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
lm_config = hf_hub_download(args.hf_repo, "config.json")
|
||||||
with open(lm_config, "r") as fobj:
|
with open(lm_config, "r") as fobj:
|
||||||
lm_config = json.load(fobj)
|
lm_config = json.load(fobj)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user