From ebbd58dd2398f6d6567cf4f5e6c1e0d43734254a Mon Sep 17 00:00:00 2001 From: Eugene Date: Wed, 2 Jul 2025 10:53:31 +0200 Subject: [PATCH] Text-audio prompt example into README.md + cutting prompt transcript. --- README.md | 16 ++++++++ audio/loona.mp3 | Bin 0 -> 9003 bytes ...cribe_from_file_via_pytorch_with_prompt.py | 35 ++++++++++++++++-- 3 files changed, 47 insertions(+), 4 deletions(-) create mode 100644 audio/loona.mp3 diff --git a/README.md b/README.md index b546f3b..b65b470 100644 --- a/README.md +++ b/README.md @@ -88,6 +88,22 @@ uv run scripts/evaluate_on_dataset.py \ --hf-repo kyutai/stt-2.6b-en ``` +Another example shows how one can provide a text-, audio-, or text-audio prompt to our STT model: +```bash +uv run scripts/transcribe_from_file_via_pytorch_with_prompt.py \ + --hf-repo kyutai/stt-2.6b-en \ + --file bria.mp3 \ + --prompt_file ./audio/loona.mp3 \ + --prompt_text "Loonah" \ + --cut-prompt-transcript +``` +This produces the transcript of `bria.mp3` using the `Loonah` spelling for the name, instead of the `Luna` used without any prompt: +``` +In the heart of an ancient forest, where the trees whispered secrets of the past, there lived a peculiar rabbit named Loonah (...) +``` +Please bear in mind that this is an experimental feature and its behavior is very sensitive to the prompt provided. 
+ + ### Rust server diff --git a/audio/loona.mp3 b/audio/loona.mp3 new file mode 100644 index 0000000000000000000000000000000000000000..997cc312ea21a7fa7310d07b3d1303925eee2a36 GIT binary patch literal 9003 zcmc(lWl$W^5~vqWOvOs_!LBrziZo%C(5FmJPcPGJ}5C}o{L2kWQ z@BjUAyJ}{7YHGUsJKb~6%sHbd!-oiX5smrluAa z7S`5&{``4xaCUZeb#-?KgS~iuj?v;dMjnxulm9F9NSpsrK`^aGW`%z`>|VVe06^m> z*Wyor!sX#ql=5D{0pQ`_?5B7v&7749nDX_j$!&hhYoj9kxbUU^q>hU+d2rKnCkl7- z1ofpxIS9m-c!SevDj~sh2SkO33#&echlh)%MwuLe2S5REZ*Tl={5Js62uoY6B^_qc zEZ@3pdG(JFrK?uq``@~DyR+xdJG zL2dL;sm^q>@M5Gu@Ha0Gv zi$~So&PE!CU*Gcyc7InrY#=1Q2`lYNWE&vf=-Qu9XmujS5q=gLkKTjmC!M?IKO&H|pETV~6Eh-N=#QH&6Y#L|+@`YiHYofLsCY^$65XUFGlFwuU}v?Co@*Zv2I z85DCi>0-p6F6V)xbWgsM^0$c>eJ{wO zFxaX5@2*bb54;&!o}e%OfA`qo-M{xs4CrtJKz0DZwT@CwA?@&gKqvrA+%~fl`fACK z^=$5;Xu{}jspx|2`)h?4WgBy_*<^(uGR` zE)_3i&K)1!kW7^pE{u&EdcAsFu%gM@cg`X+13Fm2ODF_!^*)tu~tqn z*wM($sI-k*ro;TYI1IM_8cBXE^b0ccIHbGMzP7uXS(HO==wHb}F%Bav*y&Ynt#}*` zgV6p6NjLXV+w<&OC_k~P6Nn;R8N0P%bCF%rd;jWTw>hj$J~%&hw=}zx9Uf1nS2(iY+5Zr8h5YZ*)1H#EQ%7vFu9v#whOPG2LB>VujuPOw9s)!SQ&%^TR2 z_j*hEUSX0=wK(iitMW?%dESI*B(&Ti`B;Yerzd%x1lGk15DEa}k=jZ=UHz9ck3D5@ z1h4m?LSFiElNvQA`&@ECLlC9pneFB9-j9G2_j5-m?l#+-Vp>S$#~%sNxaLBzJpr5q z+={^wsjl0u)rwc7DQgv|1p&^<{;)0egqi+e2~de9E`W1owZB5ngw8vN_A6AtLyL%% zx1JTzQ)alXLZe^Ha4y|IzlLQ?dfg}O4g&z(`DVGblusA3XC={uPuG< zGe2tnySaOw_+$SXXqIhD42uu?Y%TKrl<^KSotB4pX*Ol_f;Ipx$V!dHl zD`!P@9oRpSU&?qrlN~rRJs<5-IQ4eYWHr%E1SlyxBHjEGy;{$0^KmRhl97mc*dDA6 z44(0_f7*PUcG~|%``#TO)=ah*L)*VeWxx@uW3HTp`hFRkH<6cAiMH)3#foV)Q9P2f zGDomg*IC(Y(N6WbBM~Mm_5a2}h=?3C$is6@VTL?Kw}vjd`K0@p%c!0MMq@=P<(09e zx7jQCVXdr?OwC1J(Lg0%zUJvLJ~>DOK3(oq@W6QiB0#{7Fi>?w6B{;TISPSdWZ>+| z8^dbV+P<90%1iA|&Tf9++)v&8aBC#~mPd`;f8|ktAv{=%*kT`V?C-w zWCeXpX9c1)0N)9&lWd)fF=?ZxCZkuwPy6%gbYVcEd$kiMq!g{}okkHvF?#zA0=LdHrgt*2=> z063V!FPDZm_7if^%BZ&!&(G+0#=iiOKykA>ZQ8=1c+z7P2%;NV7 z&w0bCtz&t}!_?sV)&2W!0g3oc@OPrs?!|if3s*v&&b+jV)?>_x#wj=e5I=t#e;)zA zJ|flLw}P5BI;hoKtKP56K?FPSU8je=9;{a3Df*_LI4@$I_HpwCS3(Sd0%i6cSfJ*>U}d~l(E)T~?viqupLcqCnOU_l(#?S=RG+)OR{ 
z8w~cqj(miq3VVnE=p&&P8R5UYCx)v;h;=IfQ0du&XrqGxpN&+Px1z|F_pofvMh|$_ zti-JcMwf-9$tleDNf(6ySaIS}O!X++$C|RSjNx*XG3MBgPSw*7o<~FRZ+^!;!C;M* zAXETA(-f=h*}A1E-~|X7fN@f+*A@YVb0I#yA-O;RAUVZKYJY6145S!xXb7zev>>Bh zR`K4VE5!7T;7QPD_lzc$lo7x!2N_F*qHS^2BeUc5gVUI!_`!?}8&Q4;1Vc%5A2q@L zBfnvVrP)IB@qm-to3U+3<#($Jl0l!&`kX=4-1mWrId<>^yF zg1oeXXB*?}mv!1Kg@lu<5hhAmbKTATH!4F^XAqHO;Igo&nR+WH3yCr4=C_*otK57d%~u1&DXnNkju1m2Su&K_Q$k=D=L zt8P!Bqhfel|3ta!Jyj{D9ZMMLWdgIgk%9|{Uy`I;XsAOLg$3kMk@nQk|D`xynRXjc$YiVEj576s0B*v7X|tsvS1nCF8yNv5e(U?&b`cO*Qoy{jGlu4vR^;-MGL+Xh)nqRQg- z2)=r*B3f0Ke4wxMpI8S`H&_4XxOW0(81xs>G^f~a{?g4HqycnC$nj?Pw0OyK6iRYA ztyWXmRwbmvJF4C0_<~+@{MJHANc->4ge&2hf8<3NOJySc8Dc-IF~EZX;59}tBJx8 z?U3)0$=-vo8v&ISB1pwd+0VY|9jQP~I=RG;4>~>N^YF>|&*kV0r_lYuOo&V)qUo&{ zwl*L2bnz&LVZb;=1y3md_1S#$Or36hn~90{zl5!X_m z7g%4&FOZ=FTJN)16;F^!Cman@@rIxWubOGFC*U@e6A45Xi*56f-X>4guda=EA z6Js|YXqd(Z4eyH%*20SAJSDYZXAP%UOFMANGI+()GPj4bTT8~d;+5!cz~6FnUoylK zgl1_*qQc$RgGmY+D8$tNl%KBLiDLnA*_At=yCoSF*c2F7O6?7bukXh&bX_a0y*%YB zyn^W|-R@Sv)d)%`4n`+XiJcK^3L4R!+iUKd{dhN!Jce`hilJf51Zbp__Qn;q*^tf2 zLZ5nQao%{Ye}1FqnsWc~>GT{zgjZE59sm80eZ3;Qm>#(MCuRPR%8Nz%R0F%NHO6p* z(LTb=hpxNdj@B2M*{H=!^CO*74pRf5=2Z;CgT0L`arbwpx93F1HxknE>qbTn!R^4hLSH#jEwukG;8=;r8qnjvmn-kDxYWau6;l{$ca^|RFi zn)-Fg_Y!&XyCX8Qp;_sVDNgn-72W*x12gbJQK#J1QQo*xcd5F27qUdJDG)HS)hYB@ z;hbrRO<4z#&5g22L4+*F9;!o4Q( zybqV?ln4N8uRrS*^i@{3`)&j%$f0D&-2iVe9*MX8Mui=O-T&i_n_cx|dr`s*kU2am zy9LP@qD>DH3LRQnH13_4sYzaI9UoM$gO;xw=?!0ig8%->9g0ZWMk{{Tu!agzn;NL} zwwbD|T+34V&iL}aLpM$hW(9nNSsfe06<^}?sM>JIZH2ud0JILL+`iXHTyF;zJbj47 zq9Aovzc7=@oV#0A5zGcPlS|6^pC6vNe=$i6uiYgY2l9+B4ZB3qNDykiop!7-^_x^n zVeuujM`R=k=i$xy{R$+@_uH&UvM{Ui%;qUh4;Ikke|{Gaj|zyTOK{Jp7JC8u6o~5P zPBvC#E^(M(kr9Nc;J{G9$+L5ql~Un470Y|M0~BVbWNsIs_%j?S9`oTG&BD8*UI)EB zC7`Y?Z~-YOG>V+&^|Acj{4A{U-dL6}UYdZ+Jq(W|VI@s72eT+-n>-pN$h_A3A{D(1 zyo2`o;iye=G=D&e*E6Fii8HVDx`P#hh*pB>H=(1=cvtPVFm5u@R8iDC4-h>ZC1ZK6 z#j3#0nF&hXU3781#Y{zI#qua;+*z9VU(RKV=*csyqoCYUwRYH491QmSg|6A36Ns{7 
zN1_}p4_(ABV!ntGE@^PwDNR0Zyd)*1F^+#l{#8MJ3B<&N@s&e%B;)AYo{y51i>rb&DyMY;{<<|)FWd-S1J z`=T-+_Ij!dHATwT<7 zYBR0DEze&fK4xevQYKEBYUC%z8a(83hN&ysyGxwRwH`JI3x##oCLkBL`!j@s<2v6_ zBNu`((6hU@D%nvNn9m0I%%o&AK|d4RJ`YSVU^LFw+E)tip>7k+9nHM@I8S93#b+&$ z_%#WsDNa02*#|kH^sMfwcK6dm?K=~Ho|?O=xrFQ0>uy%`KAbO9PWNh?FF*?b{xIUIoqc!5W)a|VJlogb2P1?IIrrj zlz(Ctub^O#TytO^%B_DLp^&pn6 zF&~DK6rdGic}b5tjZ?ATcCv9yMVA$5ddgFC65oKaw|hfmBvg3#Vv(q=T$Dsxf>=nz zEGJ!m)fC5wtBpMb@Z$2KMGAdc{>|xq0dOf9%zS9k3YBZr0i8~rUUea1`d@cVS)mlPE zmY|vgxTqtaPbelw&fL%sS$?V9sDoVWU5fw6YUrhc11!0fpVvf^0 z_GsankLRUduqjZ*7YbXSvo=y9nnNqNX$XpFc;cHsmV}@7>7Kl)6JO3M=(=v0NR;;2 ziAF@#+O%ZWjX?u+#Kk_hRdh(SAp-77S--qV zqsADI>SHD33~*g6Xt-ag3qwiutm2{tRMb8gYU*4kC*fxo%qsOFz5x9U!qZ7ZRiPgp z;ceOqD2nxReVd_2Q0a(Yz;_lBA}`CB96>T7g{Fc*YgnvfSE*VjiRjI?K#!;_L**oE zW>XTDSHp4AAEq3Xr5o5mbQBXII@^&|`0{D(cSz>&A$)t#vTC1X6oSOr>2`HN4V5-CjhToR2aICFnYsV1cEX zqTb^%U$5KK#MXhmnNRfz7c-sL^|NcK4(pYDyt930ckjR5`e1&40SZRI9gaeN1{z`K z&l_&a@cc9tLbaady6u+Y9Ic|vg9)|cSP%cnFYzf8VIIfbrirR`#a)^#yOg*mr+E;3QA3-djz3}BQ>iT@A(!oqEBn{?qCwP(x zTlIxqyf)Gq|F#5Qs1?G_fLCSdpg+Oi;`E0xa=q&8%2)JQwc9ExXbp8k5l^hCq$Ph6 z+sO4pT+nQ_KVONc5?_NOGEk1@hbj^=)vKy zpnTa95>{xDgc2{Bya?&_Kl)eBCp!fV{R>bBx#wo?79rAi6?TYooWF%B>^`m(qH{i1~opQRM3^5AM`Ny zxi`=lOE?Vmon2z5pum~6deLO4Ux}Ao82AhmTbN$v2NRDPY30n5@|`Y}fjeoDrhQ=5 z9m9v}16z%$7W?Y28&-PlDhK1ZAkhJ}P?edvBYms9F_6QA-wTitJeuz!P#ICSM<@-1 z7ZweeyS9zD=QV75D}oxq;PHLBkv$#Hh4nm+*tK_9#JS=^^&WZ8fRqsD^P%x#Vl8*N zeqC#2yg^k>Im^DwM9sV&D@2-bb?i5Ks(zEyYXe4=Hl`i>7a&Ijbed&cc6w7umiU4Q zX!Hdn^EZo>R+LBs(UKp@9g8#CH3k$r=u`0+{SaCdWWnov{ZXqWIO%KiTsgih20$g% zf{@gt{fs2^Wd6q6?_Pj+BXE*3SzT~aje;a3A;I^v3u}_$4%$ZC8CTGx~$vi+fRrR#`FyPKqXf|xb zb3(VMsbi;FaX#<2;#lRY+-M9tE_obnn@#gmxFs+=`Zo@xLy?=%sF{jP;Y8Cwgn7{y znCR4Gm+d42SF(z5`YOYPd@~x1fBLf16UZ!U+$Q>;62yCl9C)NxP8FmPk)<2Emu9;G zrc0#o*@jG=Dvld9_sgS2%f*zCBAbgGF)_2eb6qxXW>0Ebma-cXQi>^ z%^_7%6Lx1&a90bJSq5a#gMfHhXc6IoJ^M)dgUmAQzxoovI`6L_vcns?_3_!)I`b)j zbmlWt_x@CRjw>mlS08ryCm|FxNwt0Siop^>ZkwbCFBxirL+9_ot=ck`G|vJJClvi< 
zJZH-#`U*TwjlIURbunieM216inUmN>>l8n|St?nVCRjddFLY)vv`t+OMPbZ>EQFHp z)s-;4ajvq4(5vTfS>PiwBO+Q0nn(9&v*xLj^ycSpl{?5dz~`YVh+eGKfRw^F+F>Wp zGcJ#QBkF@>+|IsFARoE2P0a$;*eLD5=Q-+agV#fAw!E+_rT2F0(!ytbz4 zW{7cGllre~E05dW=fLFGOiH_CwgvLb;7zy0073tjQ#u0BWNLv`!B#yIMCP&qn9e|P z-Ebaep7rhe{5VZ|NVIjn{M&$fQf=qUfDRhd@RMj+@P*xuQbBaflfX`nHd*zgX7X;i zh4+Rk`nm@vGpV^riry|)6-j+~I;}=y7!wW>9!Yl`I(Sgj ztv2;n7?L&$Y`}(Y(C~1oLn^5byO}b zc4#?RX2Bpd`nMfav02s@X9zRPt91-8`IvRlD)exm@wLh*Ku1|pMV9a7{{fxX@~PrS#HoP)WxM~Z-Yx6l)2#N$|BJNZe=09Pdx7X< xJ=mD(<^eXJ&6a*+D*jI+`Y)Zo01XBWJ2Xu)6`6-WPoqNt0L1^&=Ku8e{{!~8JbC~C literal 0 HcmV?d00001 diff --git a/scripts/transcribe_from_file_via_pytorch_with_prompt.py b/scripts/transcribe_from_file_via_pytorch_with_prompt.py index 833bb8d..5861116 100644 --- a/scripts/transcribe_from_file_via_pytorch_with_prompt.py +++ b/scripts/transcribe_from_file_via_pytorch_with_prompt.py @@ -14,7 +14,15 @@ import tqdm class PromptHook: - def __init__(self, tokenizer, prefix, padding_tokens=(0, 3,)): + def __init__( + self, + tokenizer, + prefix, + padding_tokens=( + 0, + 3, + ), + ): self.tokenizer = tokenizer self.prefix_enforce = deque(self.tokenizer.encode(prefix)) self.padding_tokens = padding_tokens @@ -102,10 +110,12 @@ def main(args): chain = [itertools.repeat(silence_chunk, n_prefix_chunks)] if audio_prompt is not None: - chain.append(torch.split(audio_prompt[:, None], mimi.frame_size, dim=-1)) + chain.append(torch.split(audio_prompt[:, None, :], mimi.frame_size, dim=-1)) + # adding a bit (0.8s) of silence to separate prompt and the actual audio + chain.append(itertools.repeat(silence_chunk, 10)) chain += [ - torch.split(audio[:, None], mimi.frame_size, dim=-1), + torch.split(audio[:, None, :], mimi.frame_size, dim=-1), itertools.repeat(silence_chunk, n_suffix_chunks), ] @@ -121,9 +131,22 @@ def main(args): utterance_tokens = torch.concat(text_tokens_accum, dim=-1) text_tokens = utterance_tokens.cpu().view(-1) + + # if we have an audio prompt and we don't want to have it in the 
transcript, + # we should cut the corresponding number of frames from the output tokens. + # However, there is also some amount of padding that happens before it + # due to silence_prefix and audio_delay. Normally it is ignored in detokenization, + # but now we should account for it to find the position of the prompt transcript. + if args.cut_prompt_transcript and audio_prompt is not None: + prompt_frames = audio_prompt.shape[1] // mimi.frame_size + no_prompt_offset_seconds = audio_delay_seconds + audio_silence_prefix_seconds + no_prompt_offset = int(no_prompt_offset_seconds * mimi.frame_rate) + text_tokens = text_tokens[prompt_frames + no_prompt_offset:] + text = tokenizer.decode( text_tokens[text_tokens > padding_token_id].numpy().tolist() ) + print(text) @@ -144,7 +167,11 @@ if __name__ == "__main__": required=False, help="Text of the prompt.", ) - + parser.add_argument( + "--cut-prompt-transcript", + action="store_true", + help="Cut the prompt from the output transcript", + ) parser.add_argument( "--hf-repo", type=str, help="HF repo to load the STT model from. " )