135 lines
4.6 KiB
Python
135 lines
4.6 KiB
Python
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
import coremltools as ct
|
|
from coremltools.models import MLModel
|
|
from initialize import Result
|
|
|
|
|
|
class AcquisitionResultANECoreML(Result):
    """Run GNSS C/A-code acquisition through a CoreML model.

    The carrier replicas (cos/sin tables) are computed on the host with
    NumPy and passed to the model as inputs, together with two 1 ms blocks
    of I/Q samples and the sampled C/A code of each PRN.

    After ``acquire`` the results live in ``self._results``: a NumPy record
    array with fields ``carrFreq``, ``codePhase`` and ``peakMetric``, one
    entry per PRN slot (32 slots; slot ``k`` holds PRN ``k + 1``).
    """

    # Number of PRN result slots (slot index = PRN - 1).
    _NUM_SLOTS = 32

    def __init__(self, settings, mlpackage_path="acq_2ms_ane_v2.mlpackage",
                 compute_units=None):
        """Load the CoreML acquisition model.

        Args:
            settings: receiver settings, forwarded to ``Result.__init__``.
            mlpackage_path: path of the ``.mlpackage`` to load.
            compute_units: optional ``coremltools.ComputeUnit``. ``None``
                keeps the previous behaviour (``ComputeUnit.CPU_ONLY``).
                CPU_ONLY restricts execution to the CPU, so pass e.g.
                ``ct.ComputeUnit.ALL`` to allow the Neural Engine.
        """
        super().__init__(settings)
        if compute_units is None:
            compute_units = ct.ComputeUnit.CPU_ONLY
        print(f"🔹 Loading CoreML model from {mlpackage_path} ...")
        self.mlmodel = MLModel(mlpackage_path, compute_units=compute_units)
        # Fetch the power output by name: predict() returns a dict keyed by
        # output name, and its iteration order is not part of the contract.
        self._output_name = self.mlmodel.get_spec().description.output[0].name
        self._results = None
        self._channels = None

    @staticmethod
    def _split_two_blocks(longSignal, N):
        """Split interleaved I/Q samples into two consecutive N-sample blocks.

        Returns ``(sig_real, sig_imag)``, each a contiguous float32 array of
        shape (2, N): row 0 is the first block, row 1 the second.

        Raises:
            ValueError: if fewer than ``4 * N`` values are supplied.
        """
        needed = 4 * N
        if len(longSignal) < needed:
            raise ValueError(
                f"longSignal has {len(longSignal)} values; need at least "
                f"{needed} (two blocks of {N} interleaved I/Q samples)"
            )
        # Row b holds the raw interleaved values of block b.
        blocks = np.asarray(longSignal[:needed]).reshape(2, 2 * N)
        sig_real = np.ascontiguousarray(blocks[:, 0::2], dtype=np.float32)
        sig_imag = np.ascontiguousarray(blocks[:, 1::2], dtype=np.float32)
        return sig_real, sig_imag

    @staticmethod
    def _carrier_tables(s, N, F):
        """Build the Doppler search grid and its carrier replicas.

        Returns ``(freqBins, cosFN, sinFN)``. ``freqBins`` holds F frequencies
        (Hz) spanning ``IF ± acqSearchBand / 2`` kHz; ``cosFN``/``sinFN`` are
        float32 arrays of shape (F, N) with cos/sin(2*pi*f*t) for each bin.
        """
        ts = 1.0 / s.samplingFreq
        phasePoints = np.arange(N) * 2 * np.pi * ts
        halfBand = s.acqSearchBand / 2 * 1e3
        freqBins = np.linspace(s.IF - halfBand, s.IF + halfBand, F)
        # Form f*t in float64: at a high IF the phase reaches tens of
        # thousands of radians, where float32 rounding becomes visible.
        phase = np.outer(freqBins, phasePoints)
        cosFN = np.cos(phase).astype(np.float32)
        sinFN = np.sin(phase).astype(np.float32)
        return freqBins, cosFN, sinFN

    @staticmethod
    def _code_sample_index(s, N):
        """Return, for each of the N sample instants, the chip index (0..1022)."""
        ts = 1.0 / s.samplingFreq
        tc = 1.0 / s.codeFreqBasis
        idx = np.ceil(ts * np.arange(1, N + 1) / tc).astype(int) - 1
        # Pin the last sample to the last chip so floating-point rounding in
        # ceil() can never produce an out-of-range index of 1023.
        idx[-1] = 1022
        return idx

    @staticmethod
    def _peak_metric(P, chipSpan):
        """Locate the correlation peak of an (F, N) power grid.

        Returns ``(ratio, fIdx, cIdx)``: the ratio of the highest peak to the
        highest value in the same frequency bin outside the code-phase window
        ``[cIdx - chipSpan, cIdx + chipSpan)``, and the peak's indices.

        The window wraps modulo N because code phase is circular (sample 0
        follows sample N-1). NOTE(review): this assumes the model's
        correlation is circular over one code period, as in an FFT-based
        search; confirm against the model definition.
        """
        nCode = P.shape[1]
        fIdx, cIdx = np.unravel_index(np.argmax(P), P.shape)
        peak = P[fIdx, cIdx]
        mask = np.ones(nCode, dtype=bool)
        mask[(cIdx + np.arange(-chipSpan, chipSpan)) % nCode] = False
        second = np.max(P[fIdx, mask])
        return peak / (second + 1e-9), fIdx, cIdx

    def acquire(self, longSignal):
        """Acquire every PRN listed in ``settings.acqSatelliteList``.

        For each PRN the model is run on two consecutive 1 ms blocks. The
        two power grids are merged with an element-wise max, and the ratio
        of the highest peak to the highest value outside a ±1 chip window
        (same frequency bin) is used as the detection metric.

        Args:
            longSignal: 1-D sequence of interleaved I/Q samples
                (I0, Q0, I1, Q1, ...) holding at least
                ``4 * settings.samplesPerCode`` values.

        Raises:
            ValueError: if ``longSignal`` is shorter than that.

        Side effects:
            Prints a one-line detection summary and sets ``self._results``
            (see the class docstring). Returns ``None``.
        """
        s = self._settings
        N = s.samplesPerCode
        F = int(round(s.acqSearchBand * 2) + 1)

        # === inputs shared by every PRN ===
        sig_real, sig_imag = self._split_two_blocks(longSignal, N)
        freqBins, cosFN, sinFN = self._carrier_tables(s, N, F)
        codeIdx = self._code_sample_index(s, N)
        chipSpan = int(round(s.samplingFreq / s.codeFreqBasis))

        carrFreq = np.zeros(self._NUM_SLOTS)
        codePhase = np.zeros(self._NUM_SLOTS)
        peakMetric = np.zeros(self._NUM_SLOTS)

        print("(")
        for prn in s.acqSatelliteList:
            # PRNs in the list are 1-based. Result slots and generateCAcode()
            # take the 0-based index, as in the original positional loop.
            slot = prn - 1
            caN = np.asarray(s.generateCAcode(slot))[codeIdx].astype(np.float32)

            inputs = {
                "sig_real": sig_real,
                "sig_imag": sig_imag,
                "ca_code": caN,
                "cosFN": cosFN,
                "sinFN": sinFN,
            }
            out_dict = self.mlmodel.predict(inputs)
            P = np.asarray(out_dict[self._output_name], dtype=np.float32)
            # Take the max across the 2 ms blocks -> [F, N]
            P = np.max(P, axis=0)

            ratio, fIdx, cIdx = self._peak_metric(P, chipSpan)
            peakMetric[slot] = ratio

            if s.acqMinThreshold < ratio < s.acqMaxThreshold:
                print(f"{prn:02d} ", end="")
                carrFreq[slot] = float(freqBins[fIdx])
                codePhase[slot] = cIdx
            else:
                print(". ", end="")
        print(")\n")

        # np.rec is the public home of record-array helpers; np.core is
        # deprecated/private in NumPy 2.
        self._results = np.rec.fromarrays(
            [carrFreq, codePhase, peakMetric],
            names="carrFreq,codePhase,peakMetric",
        )

    @staticmethod
    def plot_heatmap(power, freqBins, PRN):
        """Show a code-phase × Doppler heat map of a correlation-power grid.

        Args:
            power: 2-D array; rows are frequency bins and columns are
                code-phase samples (the (F, N) grid built in ``acquire``).
            freqBins: frequency of each row in Hz; ``freqBins[0]`` is drawn
                at the bottom.
            PRN: 0-based PRN index (the title shows ``PRN + 1``).
        """
        plt.figure(figsize=(8, 5))
        plt.imshow(
            power,
            origin="lower",
            aspect="auto",
            extent=[0, power.shape[1], freqBins[0] / 1e3, freqBins[-1] / 1e3],
            cmap="viridis",
        )
        plt.colorbar(label="Correlation Power")
        plt.xlabel("Code Phase (samples)")
        plt.ylabel("Doppler (kHz)")
        plt.title(f"Acquisition Heatmap (PRN {PRN + 1}) - ANE inference")
        plt.show()


# ================================================================
# Quick demo
# ================================================================
if __name__ == "__main__":

    class DummySettings:
        """Minimal settings stand-in for the demo run.

        Supplies the settings attributes that ``acquire`` reads. The
        "C/A codes" it generates are random ±1 sequences, not real codes.
        """

        def __init__(self):
            self.samplingFreq = 4.096e6
            self.codeFreqBasis = 1.023e6
            self.acqSearchBand = 14
            codeRepetitionRate = self.codeFreqBasis / 1023  # code periods per second
            self.samplesPerCode = int(round(self.samplingFreq / codeRepetitionRate))
            self.acqSatelliteList = range(1, 5)
            self.acqMinThreshold = 2.5
            self.acqMaxThreshold = 100.0
            self.IF = 0.0

        def generateCAcode(self, prn):
            # Reseeds NumPy's global RNG so each prn maps to a fixed sequence.
            np.random.seed(prn)
            bits = np.random.randint(0, 2, 1023)
            return 2 * (bits - 0.5)

    demoSettings = DummySettings()
    demoSignal = np.random.randn(4 * demoSettings.samplesPerCode).astype(np.float32)
    AcquisitionResultANECoreML(demoSettings).acquire(demoSignal)