# compile_onnx_models_lack_of_parameters.py
import onnx
import tvm
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx # Correct import path
def compile_model(onnx_path, target="llvm", shape_dict=None):
    """Compile an ONNX model into a TVM Relax shared library.

    Parameters
    ----------
    onnx_path : str
        Path to the ONNX model file. The output ``.so`` is written next to
        it (same path with the ``.onnx`` suffix replaced by ``.so``).
    target : str
        TVM build target, e.g. ``"llvm"`` for CPU.
    shape_dict : dict[str, tuple] | None
        Static input shapes passed to the ONNX importer. Defaults to the
        Whisper encoder/decoder shapes used by this project; previously
        these were hard-coded inside the function.

    Returns
    -------
    str
        Path of the exported shared library.
    """
    if shape_dict is None:
        # Static shapes for both encoder and decoder inputs.
        # Some ops do not support dynamic shapes, so they are pinned here.
        shape_dict = {
            "input_ids": (1, 1),
            "encoder_hidden_states": (1, 1500, 384),
        }
    # 1. Load ONNX model
    onnx_model = onnx.load(onnx_path)
    # 2. Convert to Relax IR (updated API)
    mod = from_onnx(onnx_model, shape_dict)
    # 3. Apply mandatory lowering/cleanup passes
    seq = tvm.ir.transform.Sequential([
        relax.transform.LegalizeOps(),
        relax.transform.FoldConstant(),
        relax.transform.DeadCodeElimination(),
    ])
    mod = seq(mod)
    # 4. Build for the requested target
    ex = relax.build(mod, target)
    # 5. Save the compiled artifact next to the source model
    output_path = onnx_path.replace(".onnx", ".so")
    ex.export_library(output_path)
    return output_path
# Compile both encoder and decoder
# (the encoder compile is commented out — presumably already built; confirm
# before relying on ./onnx/encoder_model_fp16.so existing)
#encoder_so = compile_model("encoder_model_fp16.onnx", target="llvm")
decoder_so = compile_model("decoder_model_fp16.onnx", target="llvm")
# run_whisper_tvm.py  (second script — originally a separate file)
import numpy as np
import tvm
from tvm import relax, runtime
from tvm.relax import VirtualMachine
from transformers import WhisperProcessor, WhisperTokenizer
import torchaudio
# === Initialize 16 empty KV tensors (shared by prefill and step-by-step decoding) ===
def init_zero_past_kv(num_layers=4, num_heads=6, head_dim=64,
                      decoder_seq_len=0, encoder_seq_len=1500, dtype="float32"):
    """Build zero-filled past key/value tensors for the decoder.

    For each of the ``num_layers`` layers, four tensors are produced in
    order: self-attention key, self-attention value, cross-attention key,
    cross-attention value. Self-attention tensors use ``decoder_seq_len``
    on the sequence axis; cross-attention tensors use ``encoder_seq_len``.

    Returns a flat list of ``4 * num_layers`` ``tvm.nd.array`` objects.
    """
    self_attn_shape = (1, num_heads, decoder_seq_len, head_dim)
    cross_attn_shape = (1, num_heads, encoder_seq_len, head_dim)
    tensors = []
    for _layer in range(num_layers):
        # Order per layer: self.key, self.value, cross.key, cross.value
        for shape in (self_attn_shape, self_attn_shape,
                      cross_attn_shape, cross_attn_shape):
            tensors.append(tvm.nd.array(np.zeros(shape, dtype=dtype)))
    return tensors
# === Load processor and tokenizer (from the current directory) ===
processor = WhisperProcessor.from_pretrained("./")
tokenizer = WhisperTokenizer.from_pretrained("./")
# === Audio -> mel spectrogram ===
waveform, sr = torchaudio.load("audio.wav")
if sr != 16000:
    # Whisper expects 16 kHz audio; resample anything else.
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
# Down-mix to mono by averaging channels.
waveform = waveform.mean(dim=0, keepdim=True)
inputs = processor(waveform.squeeze().numpy(), sampling_rate=16000, return_tensors="np")
mel = inputs.input_features.astype("float32")
# === Encoder: run the compiled encoder once over the mel features ===
encoder_vm = VirtualMachine(runtime.load_module("./onnx/encoder_model_fp16.so"), tvm.cpu())
encoder_out = encoder_vm["main"](tvm.nd.array(mel)) # shape: (1, 1500, 384)
# === Decoder Step 0: Prefill ===
decoder_prefill_vm = VirtualMachine(runtime.load_module("./onnx/decoder_model_fp16.so"), tvm.cpu())
# Derive the decoder start token from the tokenizer instead of hard-coding
# it. The previous hard-coded 50257 (flagged "# error" in the original) is
# <|endoftext|> in the multilingual Whisper vocabulary; decoding must begin
# from <|startoftranscript|>.
start_token = tokenizer.convert_tokens_to_ids("<|startoftranscript|>")
eos_token = tokenizer.eos_token_id
tokens = [start_token]
input_ids = np.array([[start_token]], dtype="int64")
# Initialize all-empty KV (self + cross) for the prefill decoder.
past_kvs = init_zero_past_kv()
#inputs = [tvm.nd.array(input_ids), encoder_out] + past_kvs
inputs = [tvm.nd.array(input_ids), encoder_out]
print("\\n=== Step 0 (Prefill) ===")
out = decoder_prefill_vm["main"](*inputs)
logits = out[0].numpy()
# Validate logits rank BEFORE indexing into it (the original asserted only
# after logits[0, -1] had already been used).
assert logits.ndim == 2 or logits.ndim == 3, "logits 維度不符"
# Greedy decoding: argmax over the last position's vocabulary distribution.
next_token = int(np.argmax(logits[0, -1]))
tokens.append(next_token)
print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")
# Extract the 16 KV tensors returned by the decoder.
decoder_kvs = list(out[1:]) # out[1]~out[16]
if next_token == eos_token:
    print("🛑 遇到 <eos>,結束解碼")
    transcript = tokenizer.decode(tokens, skip_special_tokens=True)
    print("\\n📝 Transcription:\\n", transcript)
    exit()
# === Decoder Steps 1~N: step-by-step decoding with cached KV ===
decoder_vm = VirtualMachine(runtime.load_module("./onnx/decoder_with_past_model_fp16.so"), tvm.cpu())
max_length = 64  # hard cap on generated tokens (prefill token included)
for step in range(1, max_length):
    print(f"\\n=== Step {step} ===")
    # Feed only the most recent token; past context lives in decoder_kvs.
    input_ids = np.array([[tokens[-1]]], dtype="int64")
    inputs = [tvm.nd.array(input_ids)] + decoder_kvs
    out = decoder_vm["main"](*inputs)
    logits = out[0].numpy()
    # Greedy decoding: argmax over the last position's vocabulary.
    next_token = int(np.argmax(logits[0, -1]))
    tokens.append(next_token)
    print(f"⬆️ Next token: {next_token} ({tokenizer.decode([next_token])})")
    if next_token == eos_token:
        print("🛑 遇到 <eos>,結束解碼")
        break
    # Update only the self-attention slots (indices 0,1,4,5,8,9,12,13);
    # cross-attention KV is computed once from the encoder output and reused.
    # NOTE(review): assumes the decoder-with-past model returns its 8 updated
    # self-attn tensors as out[1]..out[8] in this layer order — confirm
    # against the exported ONNX output signature.
    for i, dst_idx in enumerate([0,1,4,5,8,9,12,13]):
        decoder_kvs[dst_idx] = out[i + 1]
# === Final output ===
transcript = tokenizer.decode(tokens, skip_special_tokens=True)
print("\\n📝 Transcription:\\n", transcript)
