問題1:吃不到下一個 token
解法:先用第一個 token 跑一次 prefill decoder(decoder_model)來初始化 KV 記憶,再交給 decoder_with_past 逐步解碼
問題2:跑一次要12分鐘(正常)
code:
import numpy as np
import tvm
from tvm import relax, runtime
from tvm.relax import VirtualMachine
from transformers import WhisperProcessor, WhisperTokenizer
import torchaudio
# === Build the empty KV tensors (shared by prefill and step-by-step decoding) ===
def init_zero_past_kv(num_layers=4, num_heads=6, head_dim=64,
                      decoder_seq_len=0, encoder_seq_len=1500, dtype="float32"):
    """Return a flat list of zero-filled past key/value tensors.

    For every layer four ``tvm.nd.array`` tensors are appended, in this order:
    self-attention key, self-attention value, cross-attention key,
    cross-attention value. The result therefore holds ``4 * num_layers``
    entries (16 with the defaults).

    Args:
        num_layers: Number of decoder layers to create tensors for.
        num_heads: Size of the heads axis (axis 1) of every tensor.
        head_dim: Size of the last axis of every tensor.
        decoder_seq_len: Sequence axis length of the self-attention tensors.
        encoder_seq_len: Sequence axis length of the cross-attention tensors.
        dtype: NumPy dtype name used for the zero buffers.

    Returns:
        A list of ``tvm.nd.array`` objects; self-attention entries have shape
        ``(1, num_heads, decoder_seq_len, head_dim)`` and cross-attention
        entries ``(1, num_heads, encoder_seq_len, head_dim)``.
    """
    self_attn_shape = (1, num_heads, decoder_seq_len, head_dim)
    cross_attn_shape = (1, num_heads, encoder_seq_len, head_dim)
    # Per-layer layout: self.key, self.value, cross.key, cross.value.
    per_layer_shapes = (
        self_attn_shape,
        self_attn_shape,
        cross_attn_shape,
        cross_attn_shape,
    )

    past = []
    for _layer in range(num_layers):
        for shape in per_layer_shapes:
            # Each slot gets its own freshly allocated buffer.
            past.append(tvm.nd.array(np.zeros(shape, dtype=dtype)))
    return past
# === Load the processor and tokenizer ===
processor = WhisperProcessor.from_pretrained("./")
tokenizer = WhisperTokenizer.from_pretrained("./")


def _greedy_next_token(logits):
    """Return the arg-max token id at the last position of ``logits``.

    Accepts 2-D logits (the last row is used) or 3-D logits (the last row of
    the first batch entry is used).

    Raises:
        ValueError: If ``logits`` is neither 2-D nor 3-D.
    """
    # Validate the rank *before* indexing, so a malformed output is reported
    # instead of silently yielding a meaningless arg-max.
    if logits.ndim not in (2, 3):
        raise ValueError("logits 維度不符")
    last_position = logits[0, -1] if logits.ndim == 3 else logits[-1]
    return int(np.argmax(last_position))


# === Audio -> mel spectrogram ===
waveform, sr = torchaudio.load("audio.wav")
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
# Down-mix to a single channel regardless of the sample rate; a multi-channel
# file that is already 16 kHz must also be reduced to one channel.
waveform = waveform.mean(dim=0, keepdim=True)
features = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="np")
mel = features.input_features.astype("float32")

# === Encoder ===
encoder_vm = VirtualMachine(runtime.load_module("./onnx/encoder_model_fp16.so"), tvm.cpu())
encoder_out = encoder_vm["main"](tvm.nd.array(mel))  # expected shape: (1, 1500, 384) - TODO confirm

# === Decoder step 0: prefill ===
decoder_prefill_vm = VirtualMachine(runtime.load_module("./onnx/decoder_model_fp16.so"), tvm.cpu())
eos_token = tokenizer.eos_token_id
# Look the start token up by name. The previously hard-coded id 50257 decodes
# to <|endoftext|> with this tokenizer, which seeded decoding with EOS and
# produced an empty transcript.
start_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
if start_token is None or start_token == eos_token:
    raise ValueError("tokenizer has no distinct <|startoftranscript|> token")
tokens = [start_token]
input_ids = np.array([[start_token]], dtype="int64")
# NOTE(review): only the start token is fed here, so language/task tokens are
# left for the model to predict; confirm whether the compiled prefill module
# accepts a longer prompt before forcing them.

# Zero-initialised KV cache (self + cross) handed to the prefill decoder.
past_kvs = init_zero_past_kv()
prefill_inputs = [tvm.nd.array(input_ids), encoder_out] + past_kvs
print("\n=== Step 0 (Prefill) ===")
out = decoder_prefill_vm["main"](*prefill_inputs)
next_token = _greedy_next_token(out[0].numpy())
tokens.append(next_token)
print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")
# Everything after the logits is the KV cache returned by the prefill decoder.
decoder_kvs = list(out[1:])

if next_token == eos_token:
    print("🛑 遇到 <eos>,結束解碼")
else:
    # === Decoder steps 1..N: one token at a time ===
    decoder_vm = VirtualMachine(
        runtime.load_module("./onnx/decoder_with_past_model_fp16.so"), tvm.cpu()
    )
    max_length = 64
    # The cache is laid out per layer as [self.key, self.value, cross.key,
    # cross.value]; only the self-attention slots are refreshed each step
    # (indices 0,1,4,5,8,9,12,13 for a 16-entry cache).
    self_attn_slots = [idx for idx in range(len(decoder_kvs)) if idx % 4 < 2]
    for step in range(1, max_length):
        print(f"\n=== Step {step} ===")
        input_ids = np.array([[tokens[-1]]], dtype="int64")
        step_inputs = [tvm.nd.array(input_ids)] + decoder_kvs
        out = decoder_vm["main"](*step_inputs)
        next_token = _greedy_next_token(out[0].numpy())
        tokens.append(next_token)
        print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")
        if next_token == eos_token:
            print("🛑 遇到 <eos>,結束解碼")
            break
        # out[0] is the logits; out[1:] are the updated self-attention KVs in
        # the same order as self_attn_slots.
        for out_idx, dst_idx in enumerate(self_attn_slots, start=1):
            decoder_kvs[dst_idx] = out[out_idx]

# === Final output ===
transcript = tokenizer.decode(tokens, skip_special_tokens=True)
print("\n📝 Transcription:\n", transcript)
output:
$ python3 run_whisper_tvm.py
🔧 初始化完成:KV 數量 = 16
=== Step 0 (Prefill) ===
📥 Decoder 輸入參數資訊:
[ 0] shape=(1, 1), dtype=int64
[ 1] shape=(1, 6, 0, 64), dtype=float32
[ 2] shape=(1, 6, 0, 64), dtype=float32
[ 3] shape=(1, 6, 1500, 64), dtype=float32
[ 4] shape=(1, 6, 1500, 64), dtype=float32
[ 5] shape=(1, 6, 0, 64), dtype=float32
[ 6] shape=(1, 6, 0, 64), dtype=float32
[ 7] shape=(1, 6, 1500, 64), dtype=float32
[ 8] shape=(1, 6, 1500, 64), dtype=float32
[ 9] shape=(1, 6, 0, 64), dtype=float32
[10] shape=(1, 6, 0, 64), dtype=float32
[11] shape=(1, 6, 1500, 64), dtype=float32
[12] shape=(1, 6, 1500, 64), dtype=float32
[13] shape=(1, 6, 0, 64), dtype=float32
[14] shape=(1, 6, 0, 64), dtype=float32
[15] shape=(1, 6, 1500, 64), dtype=float32
[16] shape=(1, 6, 1500, 64), dtype=float32
⬆️ Next token: 50362 (<|nocaptions|>)
=== Step 1 ===
📥 Decoder 輸入參數資訊:
[ 0] shape=(1, 1), dtype=int64
[ 1] shape=(1, 6, 1, 64), dtype=float32
[ 2] shape=(1, 6, 1, 64), dtype=float32
[ 3] shape=(1, 6, 1500, 64), dtype=float32
[ 4] shape=(1, 6, 1500, 64), dtype=float32
[ 5] shape=(1, 6, 1, 64), dtype=float32
[ 6] shape=(1, 6, 1, 64), dtype=float32
[ 7] shape=(1, 6, 1500, 64), dtype=float32
[ 8] shape=(1, 6, 1500, 64), dtype=float32
[ 9] shape=(1, 6, 1, 64), dtype=float32
[10] shape=(1, 6, 1, 64), dtype=float32
[11] shape=(1, 6, 1500, 64), dtype=float32
[12] shape=(1, 6, 1500, 64), dtype=float32
[13] shape=(1, 6, 1, 64), dtype=float32
[14] shape=(1, 6, 1, 64), dtype=float32
[15] shape=(1, 6, 1500, 64), dtype=float32
[16] shape=(1, 6, 1500, 64), dtype=float32
✅ Decoder 輸出 9 個欄位
⬆️ Next token: 50257 (<|endoftext|>)
🛑 遇到 <eos>,結束解碼
📝 Transcription: