import torch
import torch.nn as nn
import numpy as np

try:
    import coremltools as ct
except ImportError:  # the PyTorch model is usable without coremltools; only export needs it
    ct = None


# ================================================================
# ANE-ready Acquisition model (cos/sin moved outside the network)
# ================================================================
class AcquisitionRealNet2ms(nn.Module):
    """Circular cross-correlation of mixed complex samples against a code,
    implemented with dense real DFT matrices (plain matmuls, no torch.fft).

    All complex arithmetic is split into separate real/imag tensors, so the
    traced graph contains only real-valued matmul and elementwise ops.

    Args:
        samplesPerCode: N, the correlation length (DFT size).
        numBins: F, number of mixing bins. Stored as ``self.F``; ``forward``
            itself takes F from ``cosFN.shape[0]``.
        scale: factor multiplied into both DFT matrices (default 1024.0, the
            previously hard-coded value). See ``forward`` for how it appears
            in the output.
    """

    def __init__(self, samplesPerCode: int, numBins: int, scale: float = 1024.0):
        super().__init__()
        self.N = samplesPerCode
        self.F = numBins
        self.scale = float(scale)

        # === Precompute real DFT matrices: W[k, n] = exp(-2j*pi*k*n/N) * scale ===
        # k*n is reduced modulo N in exact integer arithmetic and the twiddles
        # are looked up from a float64 table before casting to float32.
        # Forming the phase 2*pi*k*n/N directly in float32/complex64 loses
        # precision once k*n grows past 2**24 (e.g. N=8192), which can put
        # phase errors on the order of 1e-3 rad into the matrix.
        idx = torch.arange(self.N, dtype=torch.int64)
        kn_mod = (idx[:, None] * idx[None, :]).remainder_(self.N)
        phase = (2.0 * np.pi / self.N) * torch.arange(self.N, dtype=torch.float64)
        cos_tab = (torch.cos(phase) * self.scale).to(torch.float32)
        neg_sin_tab = (-torch.sin(phase) * self.scale).to(torch.float32)
        # NOTE(review): the scale was originally described as "prevent fp16
        # underflow", but the unscaled twiddles already lie in [-1, 1]; a large
        # scale mainly increases the risk of fp16 overflow further down the graph.
        self.register_buffer("W_real", cos_tab[kn_mod])
        self.register_buffer("W_imag", neg_sin_tab[kn_mod])
        # Divides out N and only ONE factor of `scale`; see forward() for the
        # net scaling of the output.
        self.register_buffer(
            "invN",
            torch.tensor(1.0 / (float(self.N) * self.scale), dtype=torch.float32),
        )

    def forward(self, sig_real, sig_imag, ca_code, cosFN, sinFN):
        """Return per-bin circular-correlation power.

        Args:
            sig_real, sig_imag: [B, N] real and imaginary parts of the samples.
            ca_code: [N] (or [1, N]) real-valued code.
            cosFN, sinFN: [F, N] precomputed cos/sin of the per-bin mixing
                phase; the samples are multiplied by (cosFN + 1j * sinFN).

        Returns:
            [B, F, N] tensor equal to
            ``scale**4 * |ifft(fft(x) * conj(fft(ca_code)))|**2``,
            where ``x = (sig_real + 1j*sig_imag) * (cosFN + 1j*sinFN)`` per bin.
            The ``scale**4`` factor arises because the three scaled-DFT matmul
            stages (FFT of x, FFT of the code, IFFT) contribute ``scale**3`` to
            the amplitude while ``invN`` removes only ``scale**1``. Reduce over
            dim 0 (e.g. ``.max(dim=0).values``) to obtain an [F, N] map.
        """
        B, N, F = sig_real.shape[0], self.N, cosFN.shape[0]

        # --- Mixing with precomputed cos/sin: x = sig * (cos + j sin) ---
        sig_r = sig_real[:, None, :]  # [B,1,N]
        sig_i = sig_imag[:, None, :]  # [B,1,N]
        x_real = sig_r * cosFN[None, :, :] - sig_i * sinFN[None, :, :]
        x_imag = sig_r * sinFN[None, :, :] + sig_i * cosFN[None, :, :]

        # --- FFT via MatMul: X = x @ W (complex product split into real parts) ---
        x_r2 = x_real.reshape(B * F, N)
        x_i2 = x_imag.reshape(B * F, N)
        X_real = torch.matmul(x_r2, self.W_real) - torch.matmul(x_i2, self.W_imag)
        X_imag = torch.matmul(x_r2, self.W_imag) + torch.matmul(x_i2, self.W_real)

        # --- FFT(code) and multiplication by its conjugate ---
        if ca_code.ndim == 1:
            ca_code = ca_code.unsqueeze(0)
        CA_real = torch.matmul(ca_code, self.W_real)
        CA_imag = torch.matmul(ca_code, self.W_imag)
        CAc_real = CA_real.squeeze(0)
        CAc_imag = -CA_imag.squeeze(0)
        conv_real = X_real * CAc_real[None, :] - X_imag * CAc_imag[None, :]
        conv_imag = X_real * CAc_imag[None, :] + X_imag * CAc_real[None, :]

        # --- IFFT via MatMul: out = conv @ conj(W) * invN ---
        # W is exactly symmetric (built from (k*n) mod N), so no transpose is needed.
        out_real = (torch.matmul(conv_real, self.W_real)
                    + torch.matmul(conv_imag, self.W_imag)) * self.invN
        out_imag = (-torch.matmul(conv_real, self.W_imag)
                    + torch.matmul(conv_imag, self.W_real)) * self.invN

        # reshape and compute power
        out_real = out_real.view(B, F, N)
        out_imag = out_imag.view(B, F, N)
        power = out_real * out_real + out_imag * out_imag
        return power


# ================================================================
# CoreML export
# ================================================================
def export_to_coreml(samplesPerCode=8192, numBins=29, save_path="acq_2ms_ane_v2.mlpackage",
                     batch_size=2):
    """Trace AcquisitionRealNet2ms and save it as a Core ML ML Program.

    Random example inputs only fix the traced/converted input shapes:
    sig_real/sig_imag [batch_size, samplesPerCode], ca_code [samplesPerCode],
    cosFN/sinFN [numBins, samplesPerCode]. The conversion targets macOS 14
    with FLOAT16 compute precision and ComputeUnit.ALL.

    NOTE(review): with the default scale=1024, scaled DFT bins of
    unit-variance input are typically around sqrt(N) * scale (~9e4 for
    N=8192), above the fp16 maximum of 65504. Verify the FLOAT16 model
    against the fp32 PyTorch output on device.

    Args:
        samplesPerCode: N, correlation length.
        numBins: F, number of mixing bins.
        save_path: destination of the .mlpackage.
        batch_size: B, leading dimension of sig_real/sig_imag (default 2).

    Raises:
        ImportError: if coremltools could not be imported.
    """
    if ct is None:
        raise ImportError("coremltools is required for export_to_coreml()")

    model = AcquisitionRealNet2ms(samplesPerCode, numBins).eval()

    sigR = torch.randn(batch_size, samplesPerCode)
    sigI = torch.randn(batch_size, samplesPerCode)
    ca = torch.randn(samplesPerCode)
    cosFN = torch.randn(numBins, samplesPerCode)
    sinFN = torch.randn(numBins, samplesPerCode)

    traced = torch.jit.trace(model, (sigR, sigI, ca, cosFN, sinFN))

    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="sig_real", shape=sigR.shape),
            ct.TensorType(name="sig_imag", shape=sigI.shape),
            ct.TensorType(name="ca_code", shape=ca.shape),
            ct.TensorType(name="cosFN", shape=cosFN.shape),
            ct.TensorType(name="sinFN", shape=sinFN.shape),
        ],
        convert_to="mlprogram",
        compute_units=ct.ComputeUnit.ALL,
        minimum_deployment_target=ct.target.macOS14,
        compute_precision=ct.precision.FLOAT16,  # ✅ ANE preferred
    )

    mlmodel.save(save_path)
    print(f"✅ Exported CoreML ANE model to {save_path}")


# ================================================================
# Test export
# ================================================================
if __name__ == "__main__":
    export_to_coreml()