SoftGNSS/acquisition_torch.py

193 lines
7.7 KiB
Python
Raw Normal View History

2025-10-22 16:08:12 +07:00
import torch
import torch.fft as fft
import numpy as np
import matplotlib.pyplot as plt
from initialize import Result
class AcquisitionResultTorch(Result):
    """PyTorch (GPU/NPU) port of the SoftGNSS C/A-code acquisition.

    ``acquire`` fills ``self._results`` with a NumPy record array of 32
    entries (index = PRN number - 1) holding ``carrFreq``, ``codePhase`` and
    ``peakMetric``; ``plot`` and ``plot_acquisition_3d`` visualise results.
    """

    def __init__(self, settings, device="cuda"):
        """Initialise the result container.

        Args:
            settings: receiver settings, passed on to ``Result.__init__``
                (``acquire`` reads them as ``self._settings``, presumably set
                by the base class).
            device: torch device on which all tensors are allocated.
        """
        super().__init__(settings)
        self.device = device
        self._results = None
        self._channels = None

    # ================================================================
    def acquire(self, longSignal):
        """GPU/NPU-based FFT acquisition using PyTorch batch processing.

        Equivalent to SoftGNSS ``acquisition.m`` but vectorized across the
        Doppler bins. For every PRN in ``settings.acqSatelliteList`` a
        parallel code-phase search is run over two consecutive 1 ms blocks.
        Satellites whose peak ratio lies strictly between
        ``settings.acqMinThreshold`` and ``settings.acqMaxThreshold`` get a
        fine carrier-frequency estimate from a 10 ms code-wiped FFT.

        Args:
            longSignal: interleaved I/Q samples (I at even, Q at odd indices),
                sliceable like a NumPy array. The fine search reads up to
                ``codePhase + 10 * samplesPerCode`` complex samples, so at
                least ``22 * samplesPerCode`` values (11 ms) are needed.

        Returns:
            None. Results are stored in ``self._results`` as a
            ``numpy.recarray`` with fields ``carrFreq`` (Hz), ``codePhase``
            (0-based sample index) and ``peakMetric`` (1st/2nd peak ratio).
        """
        settings = self._settings
        device = self.device
        dtype_c = torch.complex64
        dtype_r = torch.float32
        samplesPerCode = settings.samplesPerCode

        # --- Complex baseband from interleaved I/Q ----------------------------
        longIQ = torch.tensor(longSignal[0::2] + 1j * longSignal[1::2],
                              dtype=dtype_c, device=device)
        # Two consecutive 1 ms blocks for the coarse search.
        signal1 = longIQ[0:samplesPerCode]
        signal2 = longIQ[samplesPerCode:2 * samplesPerCode]
        # DC removal (long segment) for the fine frequency search.
        signal0DC = longIQ - torch.mean(longIQ)

        # --- Frequency bins ----------------------------------------------------
        # acqSearchBand is in kHz (note the * 1e3); round(2 * band) + 1 bins
        # spread over the band gives ~500 Hz spacing.
        numberOfFrqBins = int(round(settings.acqSearchBand * 2) + 1)
        frqBins = torch.linspace(
            settings.IF - settings.acqSearchBand / 2 * 1e3,
            settings.IF + settings.acqSearchBand / 2 * 1e3,
            numberOfFrqBins, device=device, dtype=dtype_r,
        )

        # --- Carrier replicas and mixed spectra (PRN independent) -------------
        ts = 1.0 / settings.samplingFreq
        phasePoints = (torch.arange(samplesPerCode, device=device, dtype=dtype_r)
                       * 2 * np.pi * ts)
        # Shape (numberOfFrqBins, samplesPerCode).
        # NOTE(review): mixing with exp(+j*w*t) moves a component at -f (not +f)
        # to DC. With a non-zero IF this only acquires if the I/Q data places
        # the signal at negative frequency (the fine search below reports the
        # frequency as it appears in the data). Confirm against the front-end
        # I/Q convention before changing the sign.
        carr = torch.exp(1j * (frqBins[:, None] * phasePoints[None, :]))
        # Only the code spectrum changes per PRN, so mix and FFT once.
        IQfreqDom1 = fft.fft(carr * signal1, dim=1)
        IQfreqDom2 = fft.fft(carr * signal2, dim=1)

        # --- CA code table (built on CPU, rows moved to the device per PRN) ---
        caCodesTable = settings.makeCaTable()
        samplesPerCodeChip = int(round(settings.samplingFreq / settings.codeFreqBasis))

        carrFreq = np.zeros(32)
        codePhase = np.zeros(32)
        peakMetric = np.zeros(32)

        print("(")
        # Assumes acqSatelliteList holds 1-based PRN numbers (SoftGNSS uses
        # 1..32); PRN below is the 0-based slot used for all indexing.
        for prnNumber in settings.acqSatelliteList:
            PRN = prnNumber - 1

            # --- Circular correlation for all Doppler bins at once -----------
            caCode = torch.tensor(caCodesTable[PRN, :], dtype=dtype_r, device=device)
            caCodeFreqDom = fft.fft(caCode).conj()
            acqRes1 = torch.abs(fft.ifft(IQfreqDom1 * caCodeFreqDom[None, :], dim=1)) ** 2
            acqRes2 = torch.abs(fft.ifft(IQfreqDom2 * caCodeFreqDom[None, :], dim=1)) ** 2

            # --- Per bin keep the ms with the stronger peak (as acquisition.m)
            # An element-wise max would mix the two ms and inflate the noise
            # floor used for the second-peak metric.
            use1 = acqRes1.amax(dim=1, keepdim=True) > acqRes2.amax(dim=1, keepdim=True)
            results = torch.where(use1, acqRes1, acqRes2)

            # --- Peak detection --------------------------------------------
            frequencyBinIndex = int(torch.argmax(results.amax(dim=1)).item())
            codePhaseIdx = int(torch.argmax(results.amax(dim=0)).item())
            peakSize = float(results.max().item())
            secondPeakSize = self._second_peak_size(
                results[frequencyBinIndex], codePhaseIdx, samplesPerCodeChip)

            ratio = peakSize / (secondPeakSize + 1e-12)
            peakMetric[PRN] = ratio

            if settings.acqMinThreshold < ratio < settings.acqMaxThreshold:
                print(f"{PRN + 1:02d} ", end="")
                carrFreq[PRN] = self._fine_carrier_frequency(signal0DC, PRN, codePhaseIdx)
                codePhase[PRN] = codePhaseIdx
                # To inspect the search grid:
                # self.plot_acquisition_3d(results.cpu().numpy(),
                #                          frqBins.cpu().numpy(), PRN, settings)
            else:
                print(". ", end="")
        print(")\n")

        # np.rec is the public alias; np.core is deprecated in NumPy 2.x.
        self._results = np.rec.fromarrays(
            [carrFreq, codePhase, peakMetric],
            names="carrFreq,codePhase,peakMetric",
        )

    # ================================================================
    @staticmethod
    def _second_peak_size(row, codePhaseIdx, samplesPerCodeChip):
        """Return the largest value in *row* at least one chip from the peak.

        Samples whose circular distance from *codePhaseIdx* is smaller than
        *samplesPerCodeChip* are excluded. This keeps the same samples as the
        index ranges of acquisition.m, including those exactly one chip away
        on either side, and handles wrap-around at both ends uniformly.
        """
        n = row.shape[0]
        offset = (torch.arange(n, device=row.device) - codePhaseIdx).remainder(n)
        circDist = torch.minimum(offset, n - offset)
        return float(row[circDist >= samplesPerCodeChip].max().item())

    # ================================================================
    def _fine_carrier_frequency(self, signal0DC, PRN, codePhaseIdx):
        """Estimate the carrier frequency (Hz) of *PRN* from 10 ms of signal.

        The C/A code is wiped off *signal0DC* starting at *codePhaseIdx*. The
        strongest bin of an 8x zero-padded FFT is then taken. Bins above half
        the FFT length are mapped to negative frequencies.
        """
        settings = self._settings
        device = signal0DC.device
        samplesPerCode = settings.samplesPerCode
        ts = 1.0 / settings.samplingFreq

        # Chip index for each of the 10 ms samples. It is computed in float64
        # on the host (as in acquisition.m) because float32 rounding can flip
        # floor() at chip boundaries over ~10k chips.
        codeValueIndex = np.floor(
            ts * np.arange(1, 10 * samplesPerCode + 1) / (1.0 / settings.codeFreqBasis)
        ).astype(np.int64) % 1023
        # Same 0-based PRN index as used for caCodesTable (presumably what
        # generateCAcode expects — confirm against initialize.py).
        caCode = torch.tensor(settings.generateCAcode(PRN), dtype=torch.float32, device=device)
        longCaCode = caCode[torch.from_numpy(codeValueIndex).to(device)]

        xCarrier = signal0DC[codePhaseIdx:codePhaseIdx + 10 * samplesPerCode] * longCaCode
        # Next power of two of the record length, times 8 for finer bins.
        fftNumPts = int(8 * 2 ** np.ceil(np.log2(len(xCarrier))))
        fftxc = torch.abs(fft.fft(xCarrier, n=fftNumPts))
        fftMaxIdx = int(torch.argmax(fftxc).item())

        carrierFreq = fftMaxIdx * settings.samplingFreq / fftNumPts
        if fftMaxIdx > (fftNumPts + 1) / 2.0:
            carrierFreq -= settings.samplingFreq
        return carrierFreq

    # ================================================================
    def plot(self):
        """Bar plot of acquisition metric for all PRNs."""
        assert isinstance(self._results, np.recarray), "run acquire() before plot()"
        plt.figure()
        plt.bar(range(1, 33), self._results.peakMetric)
        plt.title("Acquisition Results (GPU-PyTorch)")
        plt.xlabel("PRN Number")
        plt.ylabel("Peak Metric (1st/2nd)")
        plt.grid(True)
        plt.show()

    # ================================================================
    @staticmethod
    def plot_acquisition_3d(results, frqBins, PRN, settings):
        """
        Plot 3D surface acquisition result for one PRN.

        results: matrix [numberOfFrqBins x samplesPerCode]
        frqBins: list of Doppler frequencies (Hz)
        PRN: 0-based PRN index (the title shows PRN + 1)
        settings: unused; kept for call compatibility
        """
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers 3d projection)
        codePhases = np.arange(results.shape[1])
        dopplers = np.array(frqBins)
        X, Y = np.meshgrid(codePhases, dopplers)
        Z = results
        fig = plt.figure(figsize=(10, 6))
        ax = fig.add_subplot(111, projection='3d')
        surf = ax.plot_surface(X, Y / 1000.0, Z, cmap='viridis', linewidth=0, antialiased=False)
        fig.colorbar(surf, shrink=0.5, aspect=5)
        ax.set_title(f'Acquisition Surface (PRN {PRN + 1})')
        ax.set_xlabel('Code Phase (samples)')
        ax.set_ylabel('Doppler (kHz)')
        ax.set_zlabel('Correlation Power')
        plt.show()