105 lines
4.1 KiB
Python
105 lines
4.1 KiB
Python
|
|
import torch
|
||
|
|
import torch.nn as nn
|
||
|
|
import numpy as np
|
||
|
|
import coremltools as ct
|
||
|
|
|
||
|
|
|
||
|
|
# ================================================================
|
||
|
|
# ANE-ready Acquisition model (cos/sin moved outside the network)
|
||
|
|
# ================================================================
|
||
|
|
class AcquisitionRealNet2ms(nn.Module):
    """Circular cross-correlation of a complex signal with a code, via real matmuls.

    The forward DFT, the multiplication with the conjugated code spectrum and
    the inverse DFT are all written with real-valued dense matrices, so the
    traced graph contains no complex tensors.  The Doppler carrier replicas
    (``cosFN`` / ``sinFN``) are supplied by the caller instead of being
    computed inside the network.

    Buffers:
        W_real, W_imag: ``[N, N]`` float32 real / imaginary parts of the DFT
            matrix ``exp(-2*pi*j*k*n/N)``, each multiplied by 1024.
        invN: scalar float32 ``1 / (N * 1024)``.

    Scaling note: the scaled matrices are applied three times (signal DFT,
    code DFT, inverse DFT) while ``invN`` removes only one factor of 1024, so
    the returned power equals the unscaled correlation power times
    ``1024 ** 4``.
    """

    def __init__(self, samplesPerCode: int, numBins: int):
        """Precompute the scaled DFT matrices.

        Args:
            samplesPerCode: DFT length ``N`` (length of the last axis of every
                input to :meth:`forward`).
            numBins: number of frequency bins ``F``; stored as ``self.F`` but
                not otherwise used by this class.

        Raises:
            ValueError: if ``samplesPerCode`` is smaller than 1.
        """
        super().__init__()
        if samplesPerCode < 1:
            raise ValueError(
                f"samplesPerCode must be a positive integer, got {samplesPerCode!r}"
            )
        self.N = samplesPerCode
        self.F = numBins

        scale = 1024.0  # prevent fp16 underflow

        # === Precompute real DFT matrices ===
        # exp(-2*pi*j*k*n/N) only depends on (k*n) mod N.  Reducing the product
        # with exact integer arithmetic and looking the result up in a table of
        # the N roots of unity (evaluated in float64) avoids the precision loss
        # of forming a phase of up to ~2*pi*N radians in float32.
        idx = torch.arange(self.N, dtype=torch.int64)
        phase_idx = (idx[:, None] * idx[None, :]).remainder_(self.N)
        angles = idx.to(torch.float64) * (-2.0 * np.pi / self.N)
        cos_table = (torch.cos(angles) * scale).to(torch.float32)
        sin_table = (torch.sin(angles) * scale).to(torch.float32)

        self.register_buffer("W_real", cos_table[phase_idx])
        self.register_buffer("W_imag", sin_table[phase_idx])
        self.register_buffer(
            "invN",
            torch.tensor(1.0 / (float(self.N) * scale), dtype=torch.float32),
        )

    def forward(self, sig_real, sig_imag, ca_code, cosFN, sinFN):
        """Return the correlation power for every batch row and frequency bin.

        Args:
            sig_real, sig_imag: ``[B, N]`` real / imaginary parts of the
                signal (the export in this file uses ``B = 2``).
            ca_code: ``[N]`` (or ``[1, N]``) real-valued code.
            cosFN, sinFN: ``[F, N]`` precomputed cosine / sine of the mixing
                phase for each frequency bin.

        Returns:
            ``[B, F, N]`` tensor ``out_real**2 + out_imag**2`` (scaled by
            ``1024 ** 4``, see the class docstring).  No reduction over the
            batch axis is performed here.
        """
        B, N, F = sig_real.shape[0], self.N, cosFN.shape[0]

        # --- Doppler mixing (precomputed cos/sin) ---
        # Complex product (sig_r + j*sig_i) * (cos + j*sin), broadcast over F.
        sig_r = sig_real[:, None, :]  # [B,1,N]
        sig_i = sig_imag[:, None, :]  # [B,1,N]
        x_real = sig_r * cosFN[None, :, :] - sig_i * sinFN[None, :, :]
        x_imag = sig_r * sinFN[None, :, :] + sig_i * cosFN[None, :, :]

        # --- FFT via MatMul ---
        # Flatten (B, F) so each DFT is one row of a single matrix product.
        x_r2 = x_real.reshape(B * F, N)
        x_i2 = x_imag.reshape(B * F, N)
        X_real = torch.matmul(x_r2, self.W_real) - torch.matmul(x_i2, self.W_imag)
        X_imag = torch.matmul(x_r2, self.W_imag) + torch.matmul(x_i2, self.W_real)

        # --- FFT(CA) and conj multiplication ---
        if ca_code.ndim == 1:
            ca_code = ca_code.unsqueeze(0)
        CA_real = torch.matmul(ca_code, self.W_real)
        CA_imag = torch.matmul(ca_code, self.W_imag)
        # Conjugate of the code spectrum: negate the imaginary part.
        CAc_real = CA_real.squeeze(0)
        CAc_imag = -CA_imag.squeeze(0)

        conv_real = X_real * CAc_real[None, :] - X_imag * CAc_imag[None, :]
        conv_imag = X_real * CAc_imag[None, :] + X_imag * CAc_real[None, :]

        # --- IFFT via MatMul ---
        # Multiply by the conjugate DFT matrix (W_real - j*W_imag), then by invN.
        out_real = (torch.matmul(conv_real, self.W_real.T) + torch.matmul(conv_imag, self.W_imag.T)) * self.invN
        out_imag = (-torch.matmul(conv_real, self.W_imag.T) + torch.matmul(conv_imag, self.W_real.T)) * self.invN

        # reshape and compute power
        out_real = out_real.view(B, F, N)
        out_imag = out_imag.view(B, F, N)
        power = out_real * out_real + out_imag * out_imag
        # NOTE: a [F, N] result can be obtained by the caller with
        # torch.max(power, dim=0).values; it is intentionally not applied here.
        return power
|
||
|
|
|
||
|
|
|
||
|
|
# ================================================================
|
||
|
|
# CoreML export
|
||
|
|
# ================================================================
|
||
|
|
def export_to_coreml(samplesPerCode=8192, numBins=29, save_path="acq_2ms_ane_v2.mlpackage"):
    """Trace ``AcquisitionRealNet2ms`` and save it as a Core ML package.

    The model is built in eval mode, traced with random example tensors of the
    shapes listed below, converted to an ``mlprogram`` with FLOAT16 compute
    precision (all compute units, minimum deployment target macOS 14), written
    to ``save_path`` and a confirmation line is printed.

    Args:
        samplesPerCode: length of the last axis of every input.
        numBins: number of rows of the ``cosFN`` / ``sinFN`` inputs.
        save_path: destination path handed to ``mlmodel.save``.
    """
    net = AcquisitionRealNet2ms(samplesPerCode, numBins).eval()

    # Input name -> example shape, in the positional order of forward().
    example_shapes = (
        ("sig_real", (2, samplesPerCode)),
        ("sig_imag", (2, samplesPerCode)),
        ("ca_code", (samplesPerCode,)),
        ("cosFN", (numBins, samplesPerCode)),
        ("sinFN", (numBins, samplesPerCode)),
    )
    examples = {name: torch.randn(*shape) for name, shape in example_shapes}

    traced = torch.jit.trace(net, tuple(examples.values()))

    # Fixed-shape Core ML inputs mirroring the example tensors.
    input_specs = [
        ct.TensorType(name=name, shape=tensor.shape)
        for name, tensor in examples.items()
    ]
    mlmodel = ct.convert(
        traced,
        inputs=input_specs,
        convert_to="mlprogram",
        compute_units=ct.ComputeUnit.ALL,
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,  # ✅ ANE preferred
    )

    mlmodel.save(save_path)
    print(f"✅ Exported CoreML ANE model to {save_path}")
|
||
|
|
|
||
|
|
|
||
|
|
# ================================================================
|
||
|
|
# Test export
|
||
|
|
# ================================================================
|
||
|
|
# Script entry point: running this file directly performs the export with the
# default arguments of export_to_coreml(); importing it does nothing.
if __name__ == "__main__":
    export_to_coreml()
|