問題 1:取不到下一個 token

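解法:先用第一個 token 跑一次 prefill decoder(decoder_model),再用它輸出的 KV cache 初始化 decoder_with_past 的記憶(past key/values)。
注意:起始 token 必須是 <|startoftranscript|>。在這個模型的詞表中 50257 是 <|endoftext|>(見下方輸出)。若用它作為起始 token,模型會先輸出 <|nocaptions|> 再直接結束,轉錄結果是空的。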

問題2:跑一次要12分鐘(正常)

code:

import numpy as np
import tvm
from tvm import relax, runtime
from tvm.relax import VirtualMachine
from transformers import WhisperProcessor, WhisperTokenizer
import torchaudio

# === Zero-filled past KV caches (shared by the prefill and step-by-step decoders) ===
def init_zero_past_kv(num_layers=4, num_heads=6, head_dim=64,
                      decoder_seq_len=0, encoder_seq_len=1500, dtype="float32"):
    """Build a flat list of zero-filled KV-cache tensors, four per layer.

    Within each layer the order is: self-attn key, self-attn value,
    cross-attn key, cross-attn value. Self-attention tensors have shape
    (1, num_heads, decoder_seq_len, head_dim); cross-attention tensors have
    shape (1, num_heads, encoder_seq_len, head_dim). With the default
    ``num_layers=4`` the list holds 16 tensors.

    Returns:
        list of ``tvm.nd.array`` in the order described above.
    """
    def _zeros(seq_len):
        # One (1, heads, seq_len, head_dim) zero tensor wrapped for TVM.
        return tvm.nd.array(np.zeros((1, num_heads, seq_len, head_dim), dtype=dtype))

    # Sequence length of each of the four per-layer slots, in slot order.
    per_layer_lengths = (decoder_seq_len, decoder_seq_len, encoder_seq_len, encoder_seq_len)
    return [_zeros(length) for _ in range(num_layers) for length in per_layer_lengths]

# === Load the Whisper processor and tokenizer from the current directory ===
processor, tokenizer = (
    WhisperProcessor.from_pretrained("./"),
    WhisperTokenizer.from_pretrained("./"),
)

# === Audio -> log-mel features ===
waveform, sr = torchaudio.load("audio.wav")
# Resample only when the file is not already at the 16 kHz rate handed to the processor.
waveform = (
    waveform if sr == 16000
    else torchaudio.functional.resample(waveform, sr, 16000)
)
# Average over dim 0 (the channel axis returned by torchaudio.load), keeping that axis with size 1.
waveform = waveform.mean(dim=0, keepdim=True)
inputs = processor(
    waveform.squeeze().numpy(),
    sampling_rate=16000,
    return_tensors="np",
)
mel = inputs.input_features.astype("float32")

# === Encoder ===
encoder_vm = VirtualMachine(
    runtime.load_module("./onnx/encoder_model_fp16.so"),
    tvm.cpu(),
)
# Expected shape (1, 1500, 384) per the original note; not checked here.
encoder_out = encoder_vm["main"](tvm.nd.array(mel))

def _greedy_next_token(logits):
    """Return the highest-scoring token id at the last position of ``logits``.

    3-D logits are read as ``logits[0, -1]`` (first batch item, last position).
    2-D logits are assumed to be ``(rows, vocab)`` and the last row is used --
    TODO confirm which 2-D layout the compiled decoder can emit. Note that
    ``logits[0, -1]`` on a 2-D array is a single scalar, so ``argmax`` on it
    would silently return 0; that is why the rank is checked first.

    Raises:
        ValueError: if ``logits`` is neither 2-D nor 3-D.
    """
    if logits.ndim == 3:
        last_scores = logits[0, -1]
    elif logits.ndim == 2:
        last_scores = logits[-1]
    else:
        raise ValueError(f"logits 維度不符 (ndim={logits.ndim})")
    return int(np.argmax(last_scores))


# === Decoder Step 0: Prefill ===
decoder_prefill_vm = VirtualMachine(runtime.load_module("./onnx/decoder_model_fp16.so"), tvm.cpu())

# Look up <|startoftranscript|> by name instead of hard-coding an id: in this
# model's vocabulary 50257 is <|endoftext|> (the EOS id), and starting the
# decoder from EOS makes it emit <|nocaptions|> and then stop.
start_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
eos_token = tokenizer.eos_token_id
tokens = [start_token]
input_ids = np.array([[start_token]], dtype="int64")

# Zero-filled self + cross KV caches passed to the prefill decoder.
past_kvs = init_zero_past_kv()
prefill_args = [tvm.nd.array(input_ids), encoder_out] + past_kvs

print("\n=== Step 0 (Prefill) ===")
out = decoder_prefill_vm["main"](*prefill_args)

next_token = _greedy_next_token(out[0].numpy())
tokens.append(next_token)
print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")

# Everything after the logits is the KV cache returned by the prefill decoder
# (out[1]..out[16] with the default 4 layers), laid out per layer as
# [self.key, self.value, cross.key, cross.value].
decoder_kvs = list(out[1:])
# Self-attention slots are the first two of every group of four
# (0, 1, 4, 5, 8, 9, 12, 13 for 16 tensors); only these are refreshed per step.
self_attn_slots = [i for i in range(len(decoder_kvs)) if i % 4 < 2]

if next_token == eos_token:
    print("🛑 遇到 <eos>,結束解碼")
else:
    # === Decoder Step 1~N: step-by-step decoding ===
    decoder_vm = VirtualMachine(runtime.load_module("./onnx/decoder_with_past_model_fp16.so"), tvm.cpu())
    max_length = 64

    for step in range(1, max_length):
        print(f"\n=== Step {step} ===")
        # Feed only the most recent token; history lives in decoder_kvs.
        input_ids = np.array([[tokens[-1]]], dtype="int64")
        out = decoder_vm["main"](tvm.nd.array(input_ids), *decoder_kvs)

        next_token = _greedy_next_token(out[0].numpy())
        tokens.append(next_token)
        print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")

        if next_token == eos_token:
            print("🛑 遇到 <eos>,結束解碼")
            break

        # The with-past decoder returns logits plus one tensor per self-attn
        # slot; cross-attn caches from the prefill step are reused unchanged.
        present_self_kvs = list(out[1:])
        if len(present_self_kvs) != len(self_attn_slots):
            raise RuntimeError(
                f"decoder_with_past returned {len(present_self_kvs)} KV tensors, "
                f"expected {len(self_attn_slots)} self-attention tensors"
            )
        for slot, kv in zip(self_attn_slots, present_self_kvs):
            decoder_kvs[slot] = kv

# === Final result ===
transcript = tokenizer.decode(tokens, skip_special_tokens=True)
print("\n📝 Transcription:\n", transcript)

output:

$ python3 run_whisper_tvm.py 
🔧 初始化完成:KV 數量 = 16

=== Step 0 (Prefill) ===
📥 Decoder 輸入參數資訊:
  [ 0] shape=(1, 1), dtype=int64
  [ 1] shape=(1, 6, 0, 64), dtype=float32
  [ 2] shape=(1, 6, 0, 64), dtype=float32
  [ 3] shape=(1, 6, 1500, 64), dtype=float32
  [ 4] shape=(1, 6, 1500, 64), dtype=float32
  [ 5] shape=(1, 6, 0, 64), dtype=float32
  [ 6] shape=(1, 6, 0, 64), dtype=float32
  [ 7] shape=(1, 6, 1500, 64), dtype=float32
  [ 8] shape=(1, 6, 1500, 64), dtype=float32
  [ 9] shape=(1, 6, 0, 64), dtype=float32
  [10] shape=(1, 6, 0, 64), dtype=float32
  [11] shape=(1, 6, 1500, 64), dtype=float32
  [12] shape=(1, 6, 1500, 64), dtype=float32
  [13] shape=(1, 6, 0, 64), dtype=float32
  [14] shape=(1, 6, 0, 64), dtype=float32
  [15] shape=(1, 6, 1500, 64), dtype=float32
  [16] shape=(1, 6, 1500, 64), dtype=float32
⬆️ Next token: 50362 (<|nocaptions|>)

=== Step 1 ===
📥 Decoder 輸入參數資訊:
  [ 0] shape=(1, 1), dtype=int64
  [ 1] shape=(1, 6, 1, 64), dtype=float32
  [ 2] shape=(1, 6, 1, 64), dtype=float32
  [ 3] shape=(1, 6, 1500, 64), dtype=float32
  [ 4] shape=(1, 6, 1500, 64), dtype=float32
  [ 5] shape=(1, 6, 1, 64), dtype=float32
  [ 6] shape=(1, 6, 1, 64), dtype=float32
  [ 7] shape=(1, 6, 1500, 64), dtype=float32
  [ 8] shape=(1, 6, 1500, 64), dtype=float32
  [ 9] shape=(1, 6, 1, 64), dtype=float32
  [10] shape=(1, 6, 1, 64), dtype=float32
  [11] shape=(1, 6, 1500, 64), dtype=float32
  [12] shape=(1, 6, 1500, 64), dtype=float32
  [13] shape=(1, 6, 1, 64), dtype=float32
  [14] shape=(1, 6, 1, 64), dtype=float32
  [15] shape=(1, 6, 1500, 64), dtype=float32
  [16] shape=(1, 6, 1500, 64), dtype=float32
✅ Decoder 輸出 9 個欄位
⬆️ Next token: 50257 (<|endoftext|>)
🛑 遇到 <eos>,結束解碼

📝 Transcription: