SoftGNSS/acquisition_ane_create.py

105 lines
4.1 KiB
Python
Raw Permalink Normal View History

2025-10-22 16:08:12 +07:00
import torch
import torch.nn as nn
import numpy as np
import coremltools as ct
# ================================================================
# ANE-ready Acquisition model (cos/sin moved outside the network)
# ================================================================
class AcquisitionRealNet2ms(nn.Module):
    """Parallel code-phase acquisition built only from real-valued ops.

    Complex arithmetic is carried as separate real/imaginary tensors and every
    DFT is a dense matrix multiply against precomputed twiddle matrices, so the
    network needs neither a complex dtype nor an FFT op. Carrier replicas
    (``cosFN``/``sinFN``) are supplied as inputs instead of being generated
    inside the network.

    For every signal row ``b`` and Doppler bin ``f`` the forward pass computes
    the circular cross-correlation (NumPy FFT conventions)::

        x    = (sig_real + 1j*sig_imag) * (cosFN + 1j*sinFN)
        corr = ifft(fft(x) * conj(fft(ca_code)))

    and returns ``scale**4 * |corr|**2``. The stored DFT matrices carry a
    factor ``scale``. They are applied three times (signal DFT, code DFT,
    inverse DFT), while ``invN`` removes only one factor, so the correlation
    amplitude keeps ``scale**2``.

    NOTE(review): with the default ``scale=1024`` the power is ~1.1e12 times
    ``|corr|**2``, and the intermediate spectra are ``scale`` times larger than
    an unscaled DFT. When exported with FLOAT16 precision (max 65504) this can
    overflow unless the inputs are small. Confirm the expected input range.

    Args:
        samplesPerCode: DFT length ``N`` (samples per signal row).
        numBins: Number of Doppler bins, kept as ``self.F``. ``forward`` takes
            the actual bin count from ``cosFN.shape[0]``.
        scale: Multiplier baked into ``W_real``/``W_imag``. The default keeps
            the original 1024.0 behaviour.

    Raises:
        ValueError: If ``samplesPerCode`` is not positive.
    """

    def __init__(self, samplesPerCode: int, numBins: int, scale: float = 1024.0):
        super().__init__()
        if samplesPerCode <= 0:
            raise ValueError(f"samplesPerCode must be positive, got {samplesPerCode}")
        self.N = samplesPerCode
        self.F = numBins
        N = self.N

        # The N distinct DFT phases -2*pi*m/N (m = 0..N-1), evaluated in
        # float64 and rounded to float32 exactly once.
        m = torch.arange(N, dtype=torch.float64)
        theta = (-2.0 * np.pi / N) * m
        cos_tab = (torch.cos(theta) * scale).to(torch.float32)
        sin_tab = (torch.sin(theta) * scale).to(torch.float32)

        # W[k, n] = exp(-2j*pi*k*n/N) depends only on (k*n) mod N. Reducing the
        # exponent in exact integer arithmetic avoids the float32 phase error of
        # the direct form, whose phases grow to ~2*pi*N rad. Because k*n == n*k,
        # the resulting matrices are exactly symmetric.
        idx = torch.arange(N, dtype=torch.int64)
        kn_mod = (idx[:, None] * idx[None, :]).remainder_(N)
        self.register_buffer("W_real", cos_tab[kn_mod])
        self.register_buffer("W_imag", sin_tab[kn_mod])
        # Inverse-DFT normalisation: cancels the 1/N and ONE factor of `scale`.
        self.register_buffer("invN", torch.tensor(1.0 / (float(N) * scale), dtype=torch.float32))

    def forward(self, sig_real, sig_imag, ca_code, cosFN, sinFN):
        """Return correlation power for every (row, Doppler bin, code phase).

        Args:
            sig_real, sig_imag: ``[B, N]`` real/imaginary input samples.
            ca_code: ``[N]`` or ``[1, N]`` code replica.
            cosFN, sinFN: ``[F, N]`` carrier replica for each Doppler bin.

        Returns:
            ``[B, F, N]`` tensor equal to ``scale**4 * |corr|**2`` (see the
            class docstring). Any reduction over ``B`` is left to the caller.
        """
        B, N, F = sig_real.shape[0], self.N, cosFN.shape[0]

        # --- Doppler wipe-off: x = (sig_real + j*sig_imag) * (cosFN + j*sinFN)
        sig_r = sig_real[:, None, :]  # [B, 1, N]
        sig_i = sig_imag[:, None, :]  # [B, 1, N]
        x_real = sig_r * cosFN[None, :, :] - sig_i * sinFN[None, :, :]  # [B, F, N]
        x_imag = sig_r * sinFN[None, :, :] + sig_i * cosFN[None, :, :]  # [B, F, N]

        # --- Forward DFT as a matmul (row-vector convention): X = x @ W
        x_r2 = x_real.reshape(B * F, N)
        x_i2 = x_imag.reshape(B * F, N)
        X_real = torch.matmul(x_r2, self.W_real) - torch.matmul(x_i2, self.W_imag)
        X_imag = torch.matmul(x_r2, self.W_imag) + torch.matmul(x_i2, self.W_real)

        # --- DFT of the code replica
        if ca_code.ndim == 1:
            ca_code = ca_code.unsqueeze(0)  # [1, N] so the matmul yields one row
        CA_real = torch.matmul(ca_code, self.W_real).squeeze(0)  # [N]
        CA_imag = torch.matmul(ca_code, self.W_imag).squeeze(0)  # [N]

        # --- Multiply by the conjugate code spectrum: X * conj(CA)
        conv_real = X_real * CA_real[None, :] + X_imag * CA_imag[None, :]
        conv_imag = X_imag * CA_real[None, :] - X_real * CA_imag[None, :]

        # --- Inverse DFT: x = X @ conj(W) / N. W is exactly symmetric, so no
        # transpose is needed.
        out_real = (torch.matmul(conv_real, self.W_real) + torch.matmul(conv_imag, self.W_imag)) * self.invN
        out_imag = (torch.matmul(conv_imag, self.W_real) - torch.matmul(conv_real, self.W_imag)) * self.invN

        # --- Back to [B, F, N] and take |.|^2
        out_real = out_real.view(B, F, N)
        out_imag = out_imag.view(B, F, N)
        return out_real * out_real + out_imag * out_imag
# ================================================================
# CoreML export
# ================================================================
def export_to_coreml(samplesPerCode=8192, numBins=29, save_path="acq_2ms_ane_v2.mlpackage"):
    """Trace ``AcquisitionRealNet2ms`` and save it as a Core ML ML Program.

    Random example tensors fix the traced input shapes: two signal rows of
    ``samplesPerCode`` samples, one code replica, and ``numBins`` carrier rows.
    The conversion requests FLOAT16 compute precision, allows all compute
    units, and targets macOS 14 or later.

    Args:
        samplesPerCode: Samples per row (DFT length N).
        numBins: Number of Doppler bins F.
        save_path: Destination path for the saved model.
    """
    net = AcquisitionRealNet2ms(samplesPerCode, numBins).eval()

    # Example inputs in forward() argument order. The keys double as the
    # Core ML input names, and dict order keeps the two lists aligned.
    examples = {
        "sig_real": torch.randn(2, samplesPerCode),
        "sig_imag": torch.randn(2, samplesPerCode),
        "ca_code": torch.randn(samplesPerCode),
        "cosFN": torch.randn(numBins, samplesPerCode),
        "sinFN": torch.randn(numBins, samplesPerCode),
    }
    traced_net = torch.jit.trace(net, tuple(examples.values()))

    input_specs = [ct.TensorType(name=key, shape=tensor.shape) for key, tensor in examples.items()]
    converted = ct.convert(
        traced_net,
        inputs=input_specs,
        convert_to="mlprogram",
        compute_units=ct.ComputeUnit.ALL,
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,  # ANE-preferred precision
    )
    converted.save(save_path)
    print(f"✅ Exported CoreML ANE model to {save_path}")
# ================================================================
# Test export
# ================================================================
if __name__ == "__main__":
    # Script entry point: run the export with export_to_coreml's default arguments.
    export_to_coreml()